"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a38dd795120e1884e3396d41bf44e44fd9b1eba0"
Unverified Commit 8e981db9 authored by Tianqi Zhang (张天启), committed by GitHub

[Example] Add Graph Cross Net (GXN) example for pytorch backend (#2559)



* add sagpool example for pytorch backend

* polish sagpool example for pytorch backend

* [Example] SAGPool: use std variance

* [Example] SAGPool: change to std

* add sagpool example to index page

* add graph property prediction tag to sagpool

* [Example] add graph classification example HGP-SL

* [Example] fix sagpool

* fix bug

* [Example] change tab to space in README of hgp-sl

* remove redundant files

* remote redundant network

* [Example]: change link from code to doc in HGP-SL

* [Example] in HGP-SL, change to meaningful name

* [Example] Fix path mistake for 'hardgat'

* [Bug Fix] Fix undefined var bug in LegacyTUDataset

* upt

* [Bug Fix] Fix cache file name bug in TUDataset

* [Example] Add GXN example for pytorch backend

* modify readme

* add more exp result
Co-authored-by: zhangtianqi <tianqizh@amazon.com>
Co-authored-by: Tong He <hetong007@gmail.com>
parent 9fc5eed6
@@ -45,6 +45,7 @@ The folder contains example implementations of selected research papers related
| [GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation](#gnnfilm) | :heavy_check_mark: | | | | |
| [Hierarchical Graph Pooling with Structure Learning](#hgp-sl) | | | :heavy_check_mark: | | |
| [Graph Representation Learning via Hard and Channel-Wise Attention Networks](#hardgat) |:heavy_check_mark: | | | | |
| [Graph Cross Networks with Vertex Infomax Pooling](#gxn) | | | :heavy_check_mark: | | |
| [Towards Deeper Graph Neural Networks](#dagnn) | :heavy_check_mark: | | | | |
## 2020
@@ -73,6 +74,10 @@ The folder contains example implementations of selected research papers related
- Example code: [Pytorch](../examples/pytorch/GNN-FiLM)
- Tags: multi-relational graphs, hypernetworks, GNN architectures
- <a name="gxn"></a> Li, Maosen, et al. Graph Cross Networks with Vertex Infomax Pooling. [Paper link](https://arxiv.org/abs/2010.01804).
- Example code: [Pytorch](../examples/pytorch/gxn)
- Tags: pooling, graph classification
- <a name="dagnn"></a> Liu et al. Towards Deeper Graph Neural Networks. [Paper link](https://arxiv.org/abs/2007.09296). - <a name="dagnn"></a> Liu et al. Towards Deeper Graph Neural Networks. [Paper link](https://arxiv.org/abs/2007.09296).
- Example code: [Pytorch](../examples/pytorch/dagnn) - Example code: [Pytorch](../examples/pytorch/dagnn)
- Tags: over-smoothing, node classification - Tags: over-smoothing, node classification
......
# DGL Implementation of Graph Cross Networks with Vertex Infomax Pooling (NeurIPS 2020)
This DGL example implements the GNN model proposed in the paper [Graph Cross Networks with Vertex Infomax Pooling](https://arxiv.org/pdf/2010.01804.pdf).
The authors' original implementation is available [here](https://github.com/limaosen0/GXN).
The graph datasets used in this example
---------------------------------------
DGL's built-in LegacyTUDataset, a series of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'ENZYMES', 'IMDB-BINARY', 'IMDB-MULTI' and 'COLLAB' in this GXN implementation. Each dataset is randomly split into training and test sets with a 0.9/0.1 ratio (similar to the setting in the author's implementation).
NOTE: Following the setting of the author's implementation, for 'DD' and 'PROTEINS' we use the one-hot node label as the input node feature. For 'ENZYMES', 'IMDB-BINARY', 'IMDB-MULTI' and 'COLLAB', we use the concatenation of the one-hot node label (if available) and the one-hot node degree as input node features.
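The same preprocessing can be reproduced with the helpers shipped in this example. A minimal sketch (it assumes this example's `data_preprocess.py` is importable and that `LegacyTUDataset` can download the data to `./datasets`):
```python
from dgl.data import LegacyTUDataset

from data_preprocess import degree_as_feature, node_label_as_feature

# ENZYMES has node labels, so one-hot degrees are concatenated with one-hot labels
dataset = LegacyTUDataset("ENZYMES", raw_dir="./datasets")
dataset = degree_as_feature(dataset)                      # one-hot node degree as "feat"
dataset = node_label_as_feature(dataset, mode="concat")   # append one-hot node labels
```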
DD
- NumGraphs: 1178
- AvgNodesPerGraph: 284.32
- AvgEdgesPerGraph: 715.66
- NumFeats: 89
- NumClasses: 2
PROTEINS
- NumGraphs: 1113
- AvgNodesPerGraph: 39.06
- AvgEdgesPerGraph: 72.82
- NumFeats: 1
- NumClasses: 2
ENZYMES
- NumGraphs: 600
- AvgNodesPerGraph: 32.63
- AvgEdgesPerGraph: 62.14
- NumFeats: 18
- NumClasses: 6
IMDB-BINARY
- NumGraphs: 1000
- AvgNodesPerGraph: 19.77
- AvgEdgesPerGraph: 96.53
- NumFeats: -
- NumClasses: 2
IMDB-MULTI
- NumGraphs: 1500
- AvgNodesPerGraph: 13.00
- AvgEdgesPerGraph: 65.94
- NumFeats: -
- NumClasses: 3
COLLAB
- NumGraphs: 5000
- AvgNodesPerGraph: 74.49
- AvgEdgesPerGraph: 2457.78
- NumFeats: -
- NumClasses: 3
How to run example files
--------------------------------
To reproduce the author's results, run the following at the root directory of this example (gxn):
```bash
bash scripts/run_gxn.sh ${dataset_name} ${device_id} ${num_trials} ${print_trainlog_every}
```
To run an early-stop version of the experiment, run the following at the root directory of this example:
```bash
bash scripts/run_gxn_early_stop.sh ${dataset_name} ${device_id} ${num_trials} ${print_trainlog_every}
```
where
- dataset_name: Dataset name used in this experiment. Can be 'DD', 'PROTEINS', 'ENZYMES', 'IMDB-BINARY', 'IMDB-MULTI' or 'COLLAB'.
- device_id: ID of the computation device; -1 for CPU-only computation. For example, if you have a single GPU, set this value to 0.
- num_trials: Number of times the experiment is conducted.
- print_trainlog_every: Print the training log every this many epochs; -1 for silent training.
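For example, `bash scripts/run_gxn.sh DD 0 10 10` trains on 'DD' with GPU 0, runs 10 trials, and prints the training log every 10 epochs (these values also match the script defaults).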
NOTE: If you have problems when using 'IMDB-BINARY', 'IMDB-MULTI' or 'COLLAB', they could be caused by a bug in `LegacyTUDataset`/`TUDataset` in DGL (see [here](https://github.com/dmlc/dgl/pull/2543)). If your DGL version is 0.5.3 or older and you encounter problems like "undefined variable" (`LegacyTUDataset`) or "the argument `force_reload=False` does not work" (`TUDataset`), try one of the following:
- use `TUDataset` with `force_reload=True`
- delete dataset files
- change `degree_as_feature(dataset)` and `node_label_as_feature(dataset, mode=mode)` to `degree_as_feature(dataset, save=False)` and `node_label_as_feature(dataset, mode=mode, save=False)` in `main.py`.
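For the third workaround, the corresponding block in `main.py` would look roughly as follows (a sketch; only the `save=False` arguments are new):
```python
# preprocess: use node degree/label as node feature, without writing the
# processed features back to the dataset cache
if args.degree_as_feature:
    dataset = degree_as_feature(dataset, save=False)
    mode = "concat"
else:
    mode = "replace"
dataset = node_label_as_feature(dataset, mode=mode, save=False)
```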
Performance
-------------------------
**Accuracy**
**NOTE**: Different from our implementation, the author uses a fixed dataset split, so our results may differ from the author's. **To compare our implementation with the author's, we follow the setting in the author's implementation and perform model selection on the test set**. We also try early stopping with a patience equal to 1/5 of the total number of epochs for some datasets. The `Author's Code` results in the table below are obtained using the first fold as the test dataset.
| | DD | PROTEINS | ENZYMES | IMDB-BINARY | IMDB-MULTI | COLLAB |
| ------------------| ------------ | ----------- | ----------- | ----------- | ---------- | ---------- |
| Reported in Paper | 82.68(4.1)  | 79.91(4.1)  | 57.50(6.1)  | 78.60(2.3)  | 55.20(2.5) | 78.82(1.4) |
| Author's Code | 82.05 | 72.07 | 58.33 | 77.00 | 56.00 | 80.40 |
| DGL | 82.97(3.0) | 78.21(2.0) | 57.50(5.5) | 78.70(4.0) | 52.26(2.0) | 80.58(2.4) |
| DGL(early-stop) | 78.66(4.3) | 73.12(3.1) | 39.83(7.4) | 68.60(6.7) | 45.40(9.4) | 76.18(1.9) |
**Speed**
Device:
- CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
- GPU: Tesla V100-SXM2 16GB
Training time per epoch, in seconds
| | DD | PROTEINS | ENZYMES | IMDB-BINARY | IMDB-MULTI | COLLAB(batch_size=64) | COLLAB(batch_size=20) |
| ------------- | ----- | -------- | ------- | ----------- | ---------- | --------------------- | --------------------- |
| Author's Code | 25.32 | 2.93 | 1.53 | 2.42 | 3.58 | 96.69 | 19.78 |
| DGL | 2.64 | 1.86 | 1.03 | 1.79 | 2.45 | 23.52 | 32.29 |
import os
import sys
import logging
import torch
import numpy as np
from dgl.data import LegacyTUDataset
import json
def _load_check_mark(path:str):
if os.path.exists(path):
with open(path, 'r') as f:
return json.load(f)
else:
return {}
def _save_check_mark(path:str, marks:dict):
with open(path, 'w') as f:
json.dump(marks, f)
def node_label_as_feature(dataset:LegacyTUDataset, mode="concat", save=True):
"""
Description
-----------
Add node labels to graph node features dict
Parameters
----------
dataset : LegacyTUDataset
The dataset object
    mode : str, optional
How to add node label to the graph. Valid options are "add",
"replace" and "concat".
- "add": Directly add node_label to graph node feature dict.
- "concat": Concatenate "feat" and "node_label"
- "replace": Use "node_label" as "feat"
Default: :obj:`"concat"`
save : bool, optional
Save the result dataset.
Default: :obj:`True`
"""
# check if node label is not available
if not os.path.exists(dataset._file_path("node_labels")) or len(dataset) == 0:
logging.warning("No Node Label Data")
return dataset
# check if has cached value
check_mark_name = "node_label_as_feature"
check_mark_path = os.path.join(
dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash))
check_mark = _load_check_mark(check_mark_path)
if check_mark_name in check_mark \
and check_mark[check_mark_name] \
and not dataset._force_reload:
logging.warning("Using cached value in node_label_as_feature")
return dataset
logging.warning("Adding node labels into node features..., mode={}".format(mode))
# check if graph has "feat"
if "feat" not in dataset[0][0].ndata:
logging.warning("Dataset has no node feature 'feat'")
if mode.lower() == "concat":
mode = "replace"
# first read node labels
DS_node_labels = dataset._idx_from_zero(
np.loadtxt(dataset._file_path("node_labels"), dtype=int))
one_hot_node_labels = dataset._to_onehot(DS_node_labels)
# read graph idx
DS_indicator = dataset._idx_from_zero(
np.genfromtxt(dataset._file_path("graph_indicator"), dtype=int))
node_idx_list = []
for idx in range(np.max(DS_indicator) + 1):
node_idx = np.where(DS_indicator == idx)
node_idx_list.append(node_idx[0])
# add to node feature dict
for idx, g in zip(node_idx_list, dataset.graph_lists):
node_labels_tensor = torch.tensor(one_hot_node_labels[idx, :])
if mode.lower() == "concat":
g.ndata["feat"] = torch.cat(
(g.ndata["feat"], node_labels_tensor), dim=1)
elif mode.lower() == "add":
g.ndata["node_label"] = node_labels_tensor
else: # replace
g.ndata["feat"] = node_labels_tensor
if save:
check_mark[check_mark_name] = True
_save_check_mark(check_mark_path, check_mark)
dataset.save()
return dataset
def degree_as_feature(dataset:LegacyTUDataset, save=True):
"""
Description
-----------
Use node degree (in one-hot format) as node feature
Parameters
----------
dataset : LegacyTUDataset
The dataset object
save : bool, optional
Save the result dataset.
Default: :obj:`True`
"""
# first check if already have such feature
check_mark_name = "degree_as_feat"
feat_name = "feat"
check_mark_path = os.path.join(
dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash))
check_mark = _load_check_mark(check_mark_path)
if check_mark_name in check_mark \
and check_mark[check_mark_name] \
and not dataset._force_reload:
logging.warning("Using cached value in 'degree_as_feature'")
return dataset
logging.warning("Adding node degree into node features...")
min_degree = sys.maxsize
max_degree = 0
for i in range(len(dataset)):
degrees = dataset.graph_lists[i].in_degrees()
min_degree = min(min_degree, degrees.min().item())
max_degree = max(max_degree, degrees.max().item())
vec_len = max_degree - min_degree + 1
for i in range(len(dataset)):
num_nodes = dataset.graph_lists[i].num_nodes()
node_feat = torch.zeros((num_nodes, vec_len))
degrees = dataset.graph_lists[i].in_degrees()
node_feat[torch.arange(num_nodes), degrees - min_degree] = 1.
dataset.graph_lists[i].ndata[feat_name] = node_feat
if save:
check_mark[check_mark_name] = True
dataset.save()
_save_check_mark(check_mark_path, check_mark)
return dataset
from typing import Optional
import dgl
import torch
import torch.nn
from dgl import DGLGraph
from dgl.nn import GraphConv
from torch import Tensor
class GraphConvWithDropout(GraphConv):
"""
A GraphConv followed by a Dropout.
"""
def __init__(self, in_feats, out_feats, dropout=0.3, norm='both', weight=True,
bias=True, activation=None, allow_zero_in_degree=False):
super(GraphConvWithDropout, self).__init__(in_feats, out_feats,
norm, weight, bias,
activation,
allow_zero_in_degree)
self.dropout = torch.nn.Dropout(p=dropout)
    def forward(self, graph, feat, weight=None):
        feat = self.dropout(feat)
        return super(GraphConvWithDropout, self).forward(graph, feat, weight)
class Discriminator(torch.nn.Module):
"""
Description
-----------
    A discriminator that lets the network discriminate
    between positive (the neighborhood of the center node) and
    negative (any neighborhood in the graph) samples.
Parameters
----------
feat_dim : int
The number of channels of node features.
"""
def __init__(self, feat_dim:int):
super(Discriminator, self).__init__()
self.affine = torch.nn.Bilinear(feat_dim, feat_dim, 1)
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.affine.weight)
torch.nn.init.zeros_(self.affine.bias)
def forward(self, h_x:Tensor, h_pos:Tensor,
h_neg:Tensor, bias_pos:Optional[Tensor]=None,
bias_neg:Optional[Tensor]=None):
"""
Parameters
----------
h_x : torch.Tensor
Node features, shape: :obj:`(num_nodes, feat_dim)`
h_pos : torch.Tensor
The node features of positive samples
It has the same shape as :obj:`h_x`
h_neg : torch.Tensor
The node features of negative samples
It has the same shape as :obj:`h_x`
bias_pos : torch.Tensor
Bias parameter vector for positive scores
shape: :obj:`(num_nodes)`
bias_neg : torch.Tensor
Bias parameter vector for negative scores
shape: :obj:`(num_nodes)`
Returns
-------
(torch.Tensor, torch.Tensor)
The output scores with shape (2 * num_nodes,), (num_nodes,)
"""
score_pos = self.affine(h_pos, h_x).squeeze()
score_neg = self.affine(h_neg, h_x).squeeze()
if bias_pos is not None:
score_pos = score_pos + bias_pos
if bias_neg is not None:
score_neg = score_neg + bias_neg
logits = torch.cat((score_pos, score_neg), 0)
return logits, score_pos
class DenseLayer(torch.nn.Module):
"""
Description
-----------
Dense layer with a linear layer and an activation function
"""
def __init__(self, in_dim:int, out_dim:int,
act:str="prelu", bias=True):
super(DenseLayer, self).__init__()
self.lin = torch.nn.Linear(in_dim, out_dim, bias=bias)
self.act_type = act.lower()
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.lin.weight)
if self.lin.bias is not None:
torch.nn.init.zeros_(self.lin.bias)
if self.act_type == "prelu":
self.act = torch.nn.PReLU()
else:
self.act = torch.relu
def forward(self, x):
x = self.lin(x)
return self.act(x)
class IndexSelect(torch.nn.Module):
"""
Description
-----------
The index selection layer used by VIPool
Parameters
----------
pool_ratio : float
The pooling ratio (for keeping nodes). For example,
if `pool_ratio=0.8`, 80\% nodes will be preserved.
hidden_dim : int
The number of channels in node features.
act : str, optional
The activation function type.
Default: :obj:`'prelu'`
dist : int, optional
DO NOT USE THIS PARAMETER
"""
def __init__(self, pool_ratio:float, hidden_dim:int,
act:str="prelu", dist:int=1):
super(IndexSelect, self).__init__()
self.pool_ratio = pool_ratio
self.dist = dist
self.dense = DenseLayer(hidden_dim, hidden_dim, act)
self.discriminator = Discriminator(hidden_dim)
self.gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
def forward(self, graph:DGLGraph, h_pos:Tensor,
h_neg:Tensor, bias_pos:Optional[Tensor]=None,
bias_neg:Optional[Tensor]=None):
"""
Description
-----------
Perform index selection
Parameters
----------
graph : dgl.DGLGraph
Input graph.
h_pos : torch.Tensor
The node features of positive samples
It has the same shape as :obj:`h_x`
h_neg : torch.Tensor
The node features of negative samples
It has the same shape as :obj:`h_x`
bias_pos : torch.Tensor
Bias parameter vector for positive scores
shape: :obj:`(num_nodes)`
bias_neg : torch.Tensor
Bias parameter vector for negative scores
shape: :obj:`(num_nodes)`
"""
# compute scores
h_pos = self.dense(h_pos)
h_neg = self.dense(h_neg)
embed = self.gcn(graph, h_pos)
h_center = torch.sigmoid(embed)
logit, logit_pos = self.discriminator(h_center, h_pos,
h_neg, bias_pos,
bias_neg)
scores = torch.sigmoid(logit_pos)
# sort scores
scores, idx = torch.sort(scores, descending=True)
# select top-k
num_nodes = graph.num_nodes()
num_select_nodes = int(self.pool_ratio * num_nodes)
size_list = [num_select_nodes, num_nodes - num_select_nodes]
select_scores, _ = torch.split(scores, size_list, dim=0)
select_idx, non_select_idx = torch.split(idx, size_list, dim=0)
return logit, select_scores, select_idx, non_select_idx, embed
class GraphPool(torch.nn.Module):
"""
Description
-----------
The pooling module for graph
Parameters
----------
hidden_dim : int
The number of channels of node features.
use_gcn : bool, optional
        Whether to use a GCN in the down-sampling process.
default: :obj:`False`
"""
def __init__(self, hidden_dim:int, use_gcn=False):
super(GraphPool, self).__init__()
self.use_gcn = use_gcn
self.down_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim) \
if use_gcn else None
def forward(self, graph:DGLGraph, feat:Tensor,
select_idx:Tensor, non_select_idx:Optional[Tensor]=None,
scores:Optional[Tensor]=None, pool_graph=False):
"""
Description
-----------
Perform graph pooling.
Parameters
----------
graph : dgl.DGLGraph
The input graph
feat : torch.Tensor
The input node feature
select_idx : torch.Tensor
            The indices (in the fine graph) of the nodes kept in
            the coarse graph; obtained from previous graph
            pooling layers.
non_select_idx : torch.Tensor, optional
The index that not included in output graph.
default: :obj:`None`
scores : torch.Tensor, optional
Scores for nodes used for pooling and scaling.
default: :obj:`None`
pool_graph : bool, optional
            Whether to also perform graph pooling on the graph topology.
default: :obj:`False`
"""
if self.use_gcn:
feat = self.down_sample_gcn(graph, feat)
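        # keep only the selected nodes' features and, if scores are given, scale them by their pooling scores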
feat = feat[select_idx]
if scores is not None:
feat = feat * scores.unsqueeze(-1)
if pool_graph:
num_node_batch = graph.batch_num_nodes()
graph = dgl.node_subgraph(graph, select_idx)
graph.set_batch_num_nodes(num_node_batch)
return feat, graph
else:
return feat
class GraphUnpool(torch.nn.Module):
"""
Description
-----------
The unpooling module for graph
Parameters
----------
hidden_dim : int
The number of channels of node features.
"""
def __init__(self, hidden_dim:int):
super(GraphUnpool, self).__init__()
self.up_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
def forward(self, graph:DGLGraph,
feat:Tensor, select_idx:Tensor):
"""
Description
-----------
Perform graph unpooling
Parameters
----------
graph : dgl.DGLGraph
The input graph
feat : torch.Tensor
The input node feature
select_idx : torch.Tensor
            The indices (in the fine graph) of the nodes kept in
            the coarse graph; obtained from previous graph
            pooling layers.
"""
fine_feat = torch.zeros((graph.num_nodes(), feat.size(-1)),
device=feat.device)
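        # scatter the coarse-graph features back to their positions in the fine graph (zeros elsewhere)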
fine_feat[select_idx] = feat
fine_feat = self.up_sample_gcn(graph, fine_feat)
return fine_feat
import json
import os
from datetime import datetime
from time import time
import dgl
import torch
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch import Tensor
from torch.utils.data import random_split
from data_preprocess import degree_as_feature, node_label_as_feature
from networks import GraphClassifier
from utils import get_stats, parse_args
def compute_loss(cls_logits:Tensor, labels:Tensor,
logits_s1:Tensor, logits_s2:Tensor,
epoch:int, total_epochs:int, device:torch.device):
# classification loss
classify_loss = F.nll_loss(cls_logits, labels.to(device))
# loss for vertex infomax pooling
scale1, scale2 = logits_s1.size(0) // 2, logits_s2.size(0) // 2
s1_label_t, s1_label_f = torch.ones(scale1), torch.zeros(scale1)
s2_label_t, s2_label_f = torch.ones(scale2), torch.zeros(scale2)
s1_label = torch.cat((s1_label_t, s1_label_f), dim=0).to(device)
s2_label = torch.cat((s2_label_t, s2_label_f), dim=0).to(device)
pool_loss_s1 = F.binary_cross_entropy_with_logits(logits_s1, s1_label)
pool_loss_s2 = F.binary_cross_entropy_with_logits(logits_s2, s2_label)
pool_loss = (pool_loss_s1 + pool_loss_s2) / 2
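    # the pooling (infomax) loss weight anneals from 2 toward 1 over training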
loss = classify_loss + (2 - epoch / total_epochs) * pool_loss
return loss
def train(model:torch.nn.Module, optimizer, trainloader,
device, curr_epoch, total_epochs):
model.train()
total_loss = 0.
num_batches = len(trainloader)
for batch in trainloader:
optimizer.zero_grad()
batch_graphs, batch_labels = batch
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out, l1, l2 = model(batch_graphs,
batch_graphs.ndata["feat"])
loss = compute_loss(out, batch_labels, l1, l2,
curr_epoch, total_epochs, device)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / num_batches
@torch.no_grad()
def test(model:torch.nn.Module, loader, device):
model.eval()
correct = 0.
num_graphs = 0
for batch in loader:
batch_graphs, batch_labels = batch
num_graphs += batch_labels.size(0)
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out, _, _ = model(batch_graphs, batch_graphs.ndata["feat"])
pred = out.argmax(dim=1)
correct += pred.eq(batch_labels).sum().item()
return correct / num_graphs
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)
    # Add self-loops. We add self-loops to each graph here since "add_self_loop" does not
    # support batched graphs.
for i in range(len(dataset)):
dataset.graph_lists[i] = dgl.remove_self_loop(dataset.graph_lists[i])
dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])
# preprocess: use node degree/label as node feature
if args.degree_as_feature:
dataset = degree_as_feature(dataset)
mode = "concat"
else:
mode = "replace"
dataset = node_label_as_feature(dataset, mode=mode)
num_training = int(len(dataset) * 0.9)
num_test = len(dataset) - num_training
train_set, test_set = random_split(dataset, [num_training, num_test])
train_loader = GraphDataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
test_loader = GraphDataLoader(test_set, batch_size=args.batch_size, num_workers=1)
device = torch.device(args.device)
# Step 2: Create model =================================================================== #
num_feature, num_classes, _ = dataset.statistics()
args.in_dim = int(num_feature)
args.out_dim = int(num_classes)
args.edge_feat_dim = 0 # No edge feature in datasets that we use.
model = GraphClassifier(args).to(device)
# Step 3: Create training components ===================================================== #
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True, weight_decay=args.weight_decay)
    # Step 4: training epochs ================================================================ #
best_test_acc = 0.0
best_epoch = -1
train_times = []
for e in range(args.epochs):
s_time = time()
train_loss = train(model, optimizer, train_loader, device,
e, args.epochs)
train_times.append(time() - s_time)
test_acc = test(model, test_loader, device)
if test_acc > best_test_acc:
best_test_acc = test_acc
best_epoch = e + 1
if (e + 1) % args.print_every == 0:
log_format = "Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}"
print(log_format.format(e + 1, train_loss, test_acc, best_test_acc))
print("Best Epoch {}, final test acc {:.4f}".format(best_epoch, best_test_acc))
return best_test_acc, sum(train_times) / len(train_times)
if __name__ == "__main__":
args = parse_args()
res = []
train_times = []
for i in range(args.num_trials):
print("Trial {}/{}".format(i + 1, args.num_trials))
acc, train_time = main(args)
# acc, train_time = 0, 0
res.append(acc)
train_times.append(train_time)
mean, err_bd = get_stats(res, conf_interval=False)
print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))
out_dict = {"hyper-parameters": vars(args),
"result_date": str(datetime.now()),
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
"details": res}
with open(os.path.join(args.output_path, "{}.log".format(args.dataset)), "w") as f:
json.dump(out_dict, f, sort_keys=True, indent=4)
import json
import os
from time import time
from datetime import datetime
import dgl
from dgl.data import LegacyTUDataset, TUDataset
from torch.utils.data import random_split
import torch
from torch import Tensor
import torch.nn.functional as F
from dataloader import GraphDataLoader
from networks import GraphClassifier
from utils import get_stats, parse_args
from data_preprocess import degree_as_feature, node_label_as_feature
def compute_loss(cls_logits:Tensor, labels:Tensor,
logits_s1:Tensor, logits_s2:Tensor,
epoch:int, total_epochs:int, device:torch.device):
# classification loss
classify_loss = F.nll_loss(cls_logits, labels.to(device))
# loss for vertex infomax pooling
scale1, scale2 = logits_s1.size(0) // 2, logits_s2.size(0) // 2
s1_label_t, s1_label_f = torch.ones(scale1), torch.zeros(scale1)
s2_label_t, s2_label_f = torch.ones(scale2), torch.zeros(scale2)
s1_label = torch.cat((s1_label_t, s1_label_f), dim=0).to(device)
s2_label = torch.cat((s2_label_t, s2_label_f), dim=0).to(device)
pool_loss_s1 = F.binary_cross_entropy_with_logits(logits_s1, s1_label)
pool_loss_s2 = F.binary_cross_entropy_with_logits(logits_s2, s2_label)
pool_loss = (pool_loss_s1 + pool_loss_s2) / 2
loss = classify_loss + (2 - epoch / total_epochs) * pool_loss
return loss
def train(model:torch.nn.Module, optimizer, trainloader,
device, curr_epoch, total_epochs):
model.train()
total_loss = 0.
num_batches = len(trainloader)
for batch in trainloader:
optimizer.zero_grad()
batch_graphs, batch_labels = batch
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out, l1, l2 = model(batch_graphs,
batch_graphs.ndata["feat"])
loss = compute_loss(out, batch_labels, l1, l2,
curr_epoch, total_epochs, device)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / num_batches
@torch.no_grad()
def test(model:torch.nn.Module, loader, device):
model.eval()
correct = 0.
num_graphs = 0
for batch in loader:
batch_graphs, batch_labels = batch
num_graphs += batch_labels.size(0)
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out, _, _ = model(batch_graphs, batch_graphs.ndata["feat"])
pred = out.argmax(dim=1)
correct += pred.eq(batch_labels).sum().item()
return correct / num_graphs
@torch.no_grad()
def validate(model:torch.nn.Module, loader, device,
curr_epoch, total_epochs):
model.eval()
tt_loss = 0.
correct = 0.
num_graphs = 0
num_batchs = len(loader)
for batch in loader:
batch_graphs, batch_labels = batch
num_graphs += batch_labels.size(0)
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out, l1, l2 = model(batch_graphs, batch_graphs.ndata["feat"])
tt_loss += compute_loss(out, batch_labels, l1, l2,
curr_epoch, total_epochs, device).item()
pred = out.argmax(dim=1)
correct += pred.eq(batch_labels).sum().item()
return correct / num_graphs, tt_loss / num_batchs
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)
    # Add self-loops. We add self-loops to each graph here since "add_self_loop" does not
    # support batched graphs.
for i in range(len(dataset)):
dataset.graph_lists[i] = dgl.remove_self_loop(dataset.graph_lists[i])
dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])
# use degree as node feature
if args.degree_as_feature:
dataset = degree_as_feature(dataset)
mode = "concat"
else:
mode = "replace"
dataset = node_label_as_feature(dataset, mode=mode)
num_training = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_training - num_val
train_set, val_set, test_set = random_split(dataset, [num_training, num_val, num_test])
train_loader = GraphDataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
val_loader = GraphDataLoader(val_set, batch_size=args.batch_size, num_workers=1)
test_loader = GraphDataLoader(test_set, batch_size=args.batch_size, num_workers=1)
device = torch.device(args.device)
# Step 2: Create model =================================================================== #
num_feature, num_classes, _ = dataset.statistics()
args.in_dim = int(num_feature)
args.out_dim = int(num_classes)
args.edge_feat_dim = 0 # No edge feature in datasets that we use.
model = GraphClassifier(args).to(device)
# Step 3: Create training components ===================================================== #
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True, weight_decay=args.weight_decay)
    # Step 4: training epochs ================================================================ #
best_test_acc = 0.0
best_epoch = -1
train_times = []
bad_count = 0
best_val_loss = float("inf")
for e in range(args.epochs):
s_time = time()
train_loss = train(model, optimizer, train_loader, device,
e, args.epochs)
train_times.append(time() - s_time)
_, val_loss = validate(model, val_loader, device, e, args.epochs)
test_acc = test(model, test_loader, device)
if best_val_loss > val_loss:
best_val_loss = val_loss
best_epoch = e
bad_count = 0
best_test_acc = test_acc
else:
bad_count += 1
if bad_count > args.patience:
break
if (e + 1) % args.print_every == 0:
log_format = "Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}"
print(log_format.format(e + 1, train_loss, test_acc, best_test_acc))
print("Best Epoch {}, final test acc {:.4f}".format(best_epoch, best_test_acc))
return best_test_acc, sum(train_times) / len(train_times)
if __name__ == "__main__":
args = parse_args()
res = []
train_times = []
for i in range(args.num_trials):
print("Trial {}/{}".format(i + 1, args.num_trials))
acc, train_time = main(args)
# acc, train_time = 0, 0
res.append(acc)
train_times.append(train_time)
mean, err_bd = get_stats(res, conf_interval=False)
print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))
out_dict = {"hyper-parameters": vars(args),
"result_date": str(datetime.now()),
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
"details": res}
with open(os.path.join(args.output_path, "{}.log".format(args.dataset)), "w") as f:
json.dump(out_dict, f, sort_keys=True, indent=4)
from typing import List, Tuple, Union
from layers import *
import torch
import torch.nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch.glob import SortPooling
class GraphCrossModule(torch.nn.Module):
"""
Description
-----------
The Graph Cross Module used by Graph Cross Networks.
This module only contains graph cross layers.
Parameters
----------
pool_ratios : Union[float, List[float]]
The pooling ratios (for keeping nodes) for each layer.
For example, if `pool_ratio=0.8`, 80\% nodes will be preserved.
If a single float number is given, all pooling layers will have the
same pooling ratio.
in_dim : int
The number of input node feature channels.
out_dim : int
The number of output node feature channels.
hidden_dim : int
The number of hidden node feature channels.
cross_weight : float, optional
The weight parameter used in graph cross layers
Default: :obj:`1.0`
fuse_weight : float, optional
The weight parameter used at the end of GXN for channel fusion.
Default: :obj:`1.0`
"""
def __init__(self, pool_ratios:Union[float, List[float]], in_dim:int,
out_dim:int, hidden_dim:int, cross_weight:float=1.,
fuse_weight:float=1., dist:int=1, num_cross_layers:int=2):
super(GraphCrossModule, self).__init__()
if isinstance(pool_ratios, float):
pool_ratios = (pool_ratios, pool_ratios)
self.cross_weight = cross_weight
self.fuse_weight = fuse_weight
self.num_cross_layers = num_cross_layers
# build network
self.start_gcn_scale1 = GraphConvWithDropout(in_dim, hidden_dim)
self.start_gcn_scale2 = GraphConvWithDropout(hidden_dim, hidden_dim)
self.end_gcn = GraphConvWithDropout(2 * hidden_dim, out_dim)
self.index_select_scale1 = IndexSelect(pool_ratios[0], hidden_dim, act="prelu", dist=dist)
self.index_select_scale2 = IndexSelect(pool_ratios[1], hidden_dim, act="prelu", dist=dist)
self.start_pool_s12 = GraphPool(hidden_dim)
self.start_pool_s23 = GraphPool(hidden_dim)
self.end_unpool_s21 = GraphUnpool(hidden_dim)
self.end_unpool_s32 = GraphUnpool(hidden_dim)
self.s1_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
self.s1_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
self.s1_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
self.s2_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
self.s2_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
self.s2_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
self.s3_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
self.s3_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
self.s3_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
if num_cross_layers >= 1:
self.pool_s12_1 = GraphPool(hidden_dim, use_gcn=True)
self.unpool_s21_1 = GraphUnpool(hidden_dim)
self.pool_s23_1 = GraphPool(hidden_dim, use_gcn=True)
self.unpool_s32_1 = GraphUnpool(hidden_dim)
if num_cross_layers >= 2:
self.pool_s12_2 = GraphPool(hidden_dim, use_gcn=True)
self.unpool_s21_2 = GraphUnpool(hidden_dim)
self.pool_s23_2 = GraphPool(hidden_dim, use_gcn=True)
self.unpool_s32_2 = GraphUnpool(hidden_dim)
def forward(self, graph, feat):
# start of scale-1
graph_scale1 = graph
feat_scale1 = self.start_gcn_scale1(graph_scale1, feat)
feat_origin = feat_scale1
feat_scale1_neg = feat_scale1[torch.randperm(feat_scale1.size(0))] # negative samples
logit_s1, scores_s1, select_idx_s1, non_select_idx_s1, feat_down_s1 = \
self.index_select_scale1(graph_scale1, feat_scale1, feat_scale1_neg)
feat_scale2, graph_scale2 = self.start_pool_s12(graph_scale1, feat_scale1,
select_idx_s1, non_select_idx_s1,
scores_s1, pool_graph=True)
# start of scale-2
feat_scale2 = self.start_gcn_scale2(graph_scale2, feat_scale2)
feat_scale2_neg = feat_scale2[torch.randperm(feat_scale2.size(0))] # negative samples
logit_s2, scores_s2, select_idx_s2, non_select_idx_s2, feat_down_s2 = \
self.index_select_scale2(graph_scale2, feat_scale2, feat_scale2_neg)
feat_scale3, graph_scale3 = self.start_pool_s23(graph_scale2, feat_scale2,
select_idx_s2, non_select_idx_s2,
scores_s2, pool_graph=True)
# layer-1
res_s1_0, res_s2_0, res_s3_0 = feat_scale1, feat_scale2, feat_scale3
feat_scale1 = F.relu(self.s1_l1_gcn(graph_scale1, feat_scale1))
feat_scale2 = F.relu(self.s2_l1_gcn(graph_scale2, feat_scale2))
feat_scale3 = F.relu(self.s3_l1_gcn(graph_scale3, feat_scale3))
if self.num_cross_layers >= 1:
feat_s12_fu = self.pool_s12_1(graph_scale1, feat_scale1,
select_idx_s1, non_select_idx_s1,
scores_s1)
feat_s21_fu = self.unpool_s21_1(graph_scale1, feat_scale2, select_idx_s1)
feat_s23_fu = self.pool_s23_1(graph_scale2, feat_scale2,
select_idx_s2, non_select_idx_s2,
scores_s2)
feat_s32_fu = self.unpool_s32_1(graph_scale2, feat_scale3, select_idx_s2)
feat_scale1 = feat_scale1 + self.cross_weight * feat_s21_fu + res_s1_0
feat_scale2 = feat_scale2 + self.cross_weight * (feat_s12_fu + feat_s32_fu) / 2 + res_s2_0
feat_scale3 = feat_scale3 + self.cross_weight * feat_s23_fu + res_s3_0
# layer-2
feat_scale1 = F.relu(self.s1_l2_gcn(graph_scale1, feat_scale1))
feat_scale2 = F.relu(self.s2_l2_gcn(graph_scale2, feat_scale2))
feat_scale3 = F.relu(self.s3_l2_gcn(graph_scale3, feat_scale3))
if self.num_cross_layers >= 2:
feat_s12_fu = self.pool_s12_2(graph_scale1, feat_scale1,
select_idx_s1, non_select_idx_s1,
scores_s1)
feat_s21_fu = self.unpool_s21_2(graph_scale1, feat_scale2, select_idx_s1)
feat_s23_fu = self.pool_s23_2(graph_scale2, feat_scale2,
select_idx_s2, non_select_idx_s2,
scores_s2)
feat_s32_fu = self.unpool_s32_2(graph_scale2, feat_scale3, select_idx_s2)
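            # the second cross layer fuses cross-scale features with a down-scaled cross weight (factor 0.05)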
cross_weight = self.cross_weight * 0.05
feat_scale1 = feat_scale1 + cross_weight * feat_s21_fu
feat_scale2 = feat_scale2 + cross_weight * (feat_s12_fu + feat_s32_fu) / 2
feat_scale3 = feat_scale3 + cross_weight * feat_s23_fu
# layer-3
feat_scale1 = F.relu(self.s1_l3_gcn(graph_scale1, feat_scale1))
feat_scale2 = F.relu(self.s2_l3_gcn(graph_scale2, feat_scale2))
feat_scale3 = F.relu(self.s3_l3_gcn(graph_scale3, feat_scale3))
# final layers
feat_s3_out = self.end_unpool_s32(graph_scale2, feat_scale3, select_idx_s2) + feat_down_s2
feat_s2_out = self.end_unpool_s21(graph_scale1, feat_scale2 + feat_s3_out, select_idx_s1)
feat_agg = feat_scale1 + self.fuse_weight * feat_s2_out + self.fuse_weight * feat_down_s1
feat_agg = torch.cat((feat_agg, feat_origin), dim=1)
feat_agg = self.end_gcn(graph_scale1, feat_agg)
return feat_agg, logit_s1, logit_s2
class GraphCrossNet(torch.nn.Module):
"""
Description
-----------
The Graph Cross Network.
Parameters
----------
in_dim : int
The number of input node feature channels.
out_dim : int
The number of output node feature channels.
edge_feat_dim : int, optional
The number of input edge feature channels. Edge feature
will be passed to a Linear layer and concatenated to
input node features. Default: :obj:`0`
hidden_dim : int, optional
The number of hidden node feature channels.
Default: :obj:`96`
pool_ratios : Union[float, List[float]], optional
The pooling ratios (for keeping nodes) for each layer.
For example, if `pool_ratio=0.8`, 80\% nodes will be preserved.
If a single float number is given, all pooling layers will have the
same pooling ratio.
Default: :obj:`[0.9, 0.7]`
readout_nodes : int, optional
        Number of nodes preserved in the final sort pooling operation.
Default: :obj:`30`
conv1d_dims : List[int], optional
The number of kernels of Conv1d operations.
Default: :obj:`[16, 32]`
conv1d_kws : List[int], optional
The kernel size of Conv1d.
Default: :obj:`[5]`
cross_weight : float, optional
The weight parameter used in graph cross layers
Default: :obj:`1.0`
fuse_weight : float, optional
The weight parameter used at the end of GXN for channel fusion.
Default: :obj:`1.0`
"""
def __init__(self, in_dim:int, out_dim:int, edge_feat_dim:int=0,
hidden_dim:int=96, pool_ratios:Union[List[float], float]=[0.9, 0.7],
readout_nodes:int=30, conv1d_dims:List[int]=[16, 32],
conv1d_kws:List[int]=[5],
cross_weight:float=1., fuse_weight:float=1., dist:int=1):
super(GraphCrossNet, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.hidden_dim = hidden_dim
self.edge_feat_dim = edge_feat_dim
self.readout_nodes = readout_nodes
conv1d_kws = [hidden_dim] + conv1d_kws
if edge_feat_dim > 0:
self.in_dim += hidden_dim
self.e2l_lin = torch.nn.Linear(edge_feat_dim, hidden_dim)
else:
self.e2l_lin = None
self.gxn = GraphCrossModule(pool_ratios, in_dim=self.in_dim, out_dim=hidden_dim,
hidden_dim=hidden_dim//2, cross_weight=cross_weight,
fuse_weight=fuse_weight, dist=dist)
self.sortpool = SortPooling(readout_nodes)
# final updates
self.final_conv1 = torch.nn.Conv1d(1, conv1d_dims[0],
kernel_size=conv1d_kws[0],
stride=conv1d_kws[0])
self.final_maxpool = torch.nn.MaxPool1d(2, 2)
self.final_conv2 = torch.nn.Conv1d(conv1d_dims[0], conv1d_dims[1],
kernel_size=conv1d_kws[1], stride=1)
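        # flattened length after SortPooling -> Conv1d(kernel=stride=hidden_dim) -> MaxPool1d(2) -> Conv1d(kernel=conv1d_kws[1])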
self.final_dense_dim = int((readout_nodes - 2) / 2 + 1)
self.final_dense_dim = (self.final_dense_dim - conv1d_kws[1] + 1) * conv1d_dims[1]
if self.out_dim > 0:
self.out_lin = torch.nn.Linear(self.final_dense_dim, out_dim)
self.init_weights()
def init_weights(self):
if self.e2l_lin is not None:
torch.nn.init.xavier_normal_(self.e2l_lin.weight)
torch.nn.init.xavier_normal_(self.final_conv1.weight)
torch.nn.init.xavier_normal_(self.final_conv2.weight)
if self.out_dim > 0:
torch.nn.init.xavier_normal_(self.out_lin.weight)
def forward(self, graph:DGLGraph, node_feat:Tensor, edge_feat:Optional[Tensor]=None):
num_batch = graph.batch_size
if edge_feat is not None:
edge_feat = self.e2l_lin(edge_feat)
with graph.local_scope():
graph.edata["he"] = edge_feat
graph.update_all(fn.copy_edge("he", "m"), fn.sum("m", "hn"))
edge2node_feat = graph.ndata.pop("hn")
node_feat = torch.cat((node_feat, edge2node_feat), dim=1)
node_feat, logits1, logits2 = self.gxn(graph, node_feat)
batch_sortpool_feats = self.sortpool(graph, node_feat)
# final updates
to_conv1d = batch_sortpool_feats.unsqueeze(1)
conv1d_result = F.relu(self.final_conv1(to_conv1d))
conv1d_result = self.final_maxpool(conv1d_result)
conv1d_result = F.relu(self.final_conv2(conv1d_result))
to_dense = conv1d_result.view(num_batch, -1)
if self.out_dim > 0:
out = F.relu(self.out_lin(to_dense))
else:
out = to_dense
return out, logits1, logits2
class GraphClassifier(torch.nn.Module):
"""
Description
-----------
Graph Classifier for graph classification.
GXN + MLP
"""
def __init__(self, args):
super(GraphClassifier, self).__init__()
self.gxn = GraphCrossNet(in_dim=args.in_dim,
out_dim=args.embed_dim,
edge_feat_dim=args.edge_feat_dim,
hidden_dim=args.hidden_dim,
pool_ratios=args.pool_ratios,
readout_nodes=args.readout_nodes,
conv1d_dims=args.conv1d_dims,
conv1d_kws=args.conv1d_kws,
cross_weight=args.cross_weight,
fuse_weight=args.fuse_weight)
self.lin1 = torch.nn.Linear(args.embed_dim, args.final_dense_hidden_dim)
self.lin2 = torch.nn.Linear(args.final_dense_hidden_dim, args.out_dim)
self.dropout = args.dropout
def forward(self, graph:DGLGraph, node_feat:Tensor, edge_feat:Optional[Tensor]=None):
embed, logits1, logits2 = self.gxn(graph, node_feat, edge_feat)
logits = F.relu(self.lin1(embed))
if self.dropout > 0:
logits = F.dropout(logits, p=self.dropout, training=self.training)
logits = self.lin2(logits)
return F.log_softmax(logits, dim=1), logits1, logits2
#!/bin/bash
# input arguments
DATA="${1-DD}" # ENZYMES, DD, PROTEINS, COLLAB, IMDB-BINARY, IMDB-MULTI
device=${2-0}
num_trials=${3-10}
print_every=${4-10}
# general settings
hidden_gxn=96
k1=0.8
k2=0.7
sortpooling_k=30
hidden_final=128
batch_size=64
dropout=0.5
cross_weight=1.0
fuse_weight=0.9
weight_decay=1e-3
# dataset-specific settings
case ${DATA} in
IMDB-BINARY)
num_epochs=200
learning_rate=0.001
sortpooling_k=31
k1=0.8
k2=0.5
;;
IMDB-MULTI)
num_epochs=200
learning_rate=0.001
sortpooling_k=22
k1=0.8
k2=0.7
;;
COLLAB)
num_epochs=100
learning_rate=0.001
sortpooling_k=130
k1=0.9
k2=0.5
;;
DD)
num_epochs=100
learning_rate=0.0005
sortpooling_k=291
k1=0.8
k2=0.6
;;
PROTEINS)
num_epochs=100
learning_rate=0.001
sortpooling_k=32
k1=0.8
k2=0.7
;;
ENZYMES)
num_epochs=500
learning_rate=0.0001
sortpooling_k=42
k1=0.7
k2=0.5
;;
*)
num_epochs=500
learning_rate=0.00001
;;
esac
python main.py \
--dataset $DATA \
--lr $learning_rate \
--epochs $num_epochs \
--hidden_dim $hidden_gxn \
--final_dense_hidden_dim $hidden_final \
--readout_nodes $sortpooling_k \
--pool_ratios $k1 $k2 \
--batch_size $batch_size \
--device $device \
--dropout $dropout \
    --cross_weight $cross_weight \
    --fuse_weight $fuse_weight \
    --weight_decay $weight_decay \
    --num_trials $num_trials \
    --print_every $print_every
#!/bin/bash
# input arguments
DATA="${1-DD}" # ENZYMES, DD, PROTEINS, COLLAB, IMDB-BINARY, IMDB-MULTI
device=${2-0}
num_trials=${3-10}
print_every=${4-10}
# general settings
hidden_gxn=96
k1=0.8
k2=0.7
sortpooling_k=30
hidden_final=128
batch_size=64
dropout=0.5
cross_weight=1.0
fuse_weight=0.9
weight_decay=1e-3
# dataset-specific settings
case ${DATA} in
IMDB-BINARY)
num_epochs=200
patience=40
learning_rate=0.001
sortpooling_k=31
k1=0.8
k2=0.5
;;
IMDB-MULTI)
num_epochs=200
patience=40
learning_rate=0.001
sortpooling_k=22
k1=0.8
k2=0.7
;;
COLLAB)
num_epochs=100
patience=20
learning_rate=0.001
sortpooling_k=130
k1=0.9
k2=0.5
;;
DD)
num_epochs=100
patience=20
learning_rate=0.0005
sortpooling_k=291
k1=0.8
k2=0.6
;;
PROTEINS)
num_epochs=100
patience=20
learning_rate=0.001
sortpooling_k=32
k1=0.8
k2=0.7
;;
ENZYMES)
num_epochs=500
patience=100
learning_rate=0.0001
sortpooling_k=42
k1=0.7
k2=0.5
;;
*)
num_epochs=500
patience=100
learning_rate=0.00001
;;
esac
python main_early_stop.py \
--dataset $DATA \
--lr $learning_rate \
--epochs $num_epochs \
--hidden_dim $hidden_gxn \
--final_dense_hidden_dim $hidden_final \
--readout_nodes $sortpooling_k \
--pool_ratios $k1 $k2 \
--batch_size $batch_size \
--device $device \
--dropout $dropout \
    --cross_weight $cross_weight \
    --fuse_weight $fuse_weight \
    --weight_decay $weight_decay \
    --num_trials $num_trials \
    --print_every $print_every \
    --patience $patience
import argparse
import logging
import math
import os
import random
import numpy as np
import torch
import torch.cuda
from scipy.stats import t
def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False):
"""Compute mean and standard deviation from an numerical array
Args:
array (array like obj): The numerical array, this array can be
convert to :obj:`torch.Tensor`.
conf_interval (bool, optional): If True, compute the confidence interval bound (95%)
instead of the std value. (default: :obj:`False`)
name (str, optional): The name of this numerical array, for log usage.
(default: :obj:`None`)
stdout (bool, optional): Whether to output result to the terminal.
(default: :obj:`False`)
logout (bool, optional): Whether to output result via logging module.
(default: :obj:`False`)
"""
eps = 1e-9
array = torch.Tensor(array)
std, mean = torch.std_mean(array)
std = std.item()
mean = mean.item()
center = mean
if conf_interval:
n = array.size(0)
se = std / (math.sqrt(n) + eps)
t_value = t.ppf(0.975, df=n-1)
err_bound = t_value * se
else:
err_bound = std
# log and print
if name is None:
name = "array {}".format(id(array))
log = "{}: {:.4f}(+-{:.4f})".format(name, center, err_bound)
if stdout:
print(log)
if logout:
logging.info(log)
return center, err_bound
def parse_args():
parser = argparse.ArgumentParser("Graph Cross Network")
parser.add_argument("--pool_ratios", nargs="+", type=float,
help="The pooling ratios used in graph cross layers")
parser.add_argument("--hidden_dim", type=int, default=96,
help="The number of hidden channels in GXN")
parser.add_argument("--cross_weight", type=float, default=1.,
help="Weight parameter used in graph cross layer")
parser.add_argument("--fuse_weight", type=float, default=1.,
help="Weight parameter for feature fusion")
parser.add_argument("--num_cross_layers", type=int, default=2,
help="The number of graph corss layers")
parser.add_argument("--readout_nodes", type=int, default=30,
help="Number of nodes for each graph after final graph pooling")
parser.add_argument("--conv1d_dims", nargs="+", type=int,
help="Number of channels in conv operations in the end of graph cross net")
parser.add_argument("--conv1d_kws", nargs="+", type=int,
help="Kernel sizes of conv1d operations")
parser.add_argument("--dropout", type=float, default=0.,
help="Dropout rate")
parser.add_argument("--embed_dim", type=int, default=1024,
help="Number of channels of graph embedding")
parser.add_argument("--final_dense_hidden_dim", type=int, default=128,
help="The number of hidden channels in final dense layers")
parser.add_argument("--batch_size", type=int, default=64,
help="Batch size")
parser.add_argument("--lr", type=float, default=1e-4,
help="Learning rate")
parser.add_argument("--weight_decay", type=float, default=0.,
help="Weight decay rate")
parser.add_argument("--epochs", type=int, default=1000,
help="Number of training epochs")
parser.add_argument("--patience", type=int, default=20,
help="Patience for early stopping")
parser.add_argument("--num_trials", type=int, default=1,
help="Number of trials")
parser.add_argument("--device", type=int, default=0,
help="Computation device id, -1 for cpu")
parser.add_argument("--dataset", type=str, default="DD",
help="Dataset used for training")
parser.add_argument("--seed", type=int, default=-1,
help="Random seed, -1 for unset")
parser.add_argument("--print_every", type=int, default=10,
help="Print train log every ? epochs, -1 for silence training")
parser.add_argument("--dataset_path", type=str, default="./datasets",
help="Path holding your dataset")
parser.add_argument("--output_path", type=str, default="./output",
help="Path holding your result files")
args = parser.parse_args()
# default value for list hyper-parameters
if not args.pool_ratios or len(args.pool_ratios) < 2:
args.pool_ratios = [0.8, 0.7]
logging.warning("No valid pool_ratios is given, "
"using default value '{}'".format(args.pool_ratios))
if not args.conv1d_dims or len(args.conv1d_dims) < 2:
args.conv1d_dims = [16, 32]
logging.warning("No valid conv1d_dims is give, "
"using default value {}".format(args.conv1d_dims))
if not args.conv1d_kws or len(args.conv1d_kws) < 1:
args.conv1d_kws = [5]
logging.warning("No valid conv1d_kws is given, "
"using default value '{}'".format(args.conv1d_kws))
# device
args.device = "cpu" if args.device < 0 else "cuda:{}".format(args.device)
if not torch.cuda.is_available():
logging.warning("GPU is not available, using CPU for training")
args.device = "cpu"
else:
logging.warning("Device: {}".format(args.device))
# random seed
if args.seed >= 0:
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
if args.device != "cpu":
torch.cuda.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# print every
if args.print_every < 0:
args.print_every = args.epochs + 1
# path
paths = [args.output_path, args.dataset_path]
for p in paths:
if not os.path.exists(p):
os.makedirs(p)
# datasets ad-hoc
if args.dataset in ['COLLAB', 'IMDB-BINARY', 'IMDB-MULTI', 'ENZYMES']:
args.degree_as_feature = True
else:
args.degree_as_feature = False
return args