Commit 684a61ad authored by HQ, committed by VoVAllen

[Model] DiffPool with both DGL and tensorized operations (#665)

* diffpool original file added

* make diffpool fuse up and running

* minor tweak on tu dataset statistics method

* fix tu

* break

* delete break

* pre_org

* diffpool fuse reorg

* fix random shuffling

* fix bn

* add dgl layers

* early stopping

* add readme

* fix

* add diffpool preprocess script

* tweak tu dataset

* tweak

* tweak

* tweak

* tweak

* tweak

* preprocess dataset

* fix early stopping

* fix

* fix

* fix

* tweak

* readme

* code review

* code review

* dataset code review

* update README

* code review

* tu doc
parent fc2e166b
Hierarchical Graph Representation Learning with Differentiable Pooling
============
Paper link: [https://arxiv.org/abs/1806.08804](https://arxiv.org/abs/1806.08804)
Author's code repo: [https://github.com/RexYing/diffpool](https://github.com/RexYing/diffpool)
This folder contains a DGL implementation of the DiffPool model. The first pooling layer is computed with DGL, and the following pooling layers are computed with tensorized operations, since the pooled graphs are dense.
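After the first DGL-based pooling step, the coarsened graphs are handled as dense tensors. Below is a minimal sketch of the dense coarsening math on toy tensors (names and shapes are illustrative only, not this module's actual API):

```python
import torch

# toy sizes: batch of 2 pooled graphs, 6 nodes, 16-d features, 3 clusters
B, N, D, K = 2, 6, 16, 3
adj = torch.rand(B, N, N)                        # dense adjacency of the pooled graphs
z = torch.randn(B, N, D)                         # node embeddings from a GNN layer
s = torch.softmax(torch.randn(B, N, K), dim=-1)  # soft cluster assignment

x_next = torch.matmul(s.transpose(-1, -2), z)         # (B, K, D) pooled node features
adj_next = s.transpose(-1, -2).matmul(adj).matmul(s)  # (B, K, K) coarsened adjacency
```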
Dependencies
------------
* PyTorch 1.0+
How to run
----------
```bash
python train.py --dataset ENZYMES --pool_ratio 0.10 --num_pool 1
python train.py --dataset DD --pool_ratio 0.15 --num_pool 1
```
Performance
-----------
* ENZYMES: 63.33% (with early stopping)
* DD: 79.31% (with early stopping)
import numpy as np
import torch
def one_hotify(labels, pad=-1):
'''
cast label to one hot vector
'''
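    # e.g. one_hotify([0, 2, 1]) -> [[1, 0, 0], [0, 0, 1], [0, 1, 0]]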
num_instances = len(labels)
    if pad <= 0:
        dim_embedding = np.max(labels) + 1  # zero-indexed labels assumed
    else:
        dim_embedding = pad + 1
embeddings = np.zeros((num_instances, dim_embedding))
embeddings[np.arange(num_instances), labels] = 1
return embeddings
def pre_process(dataset, prog_args):
"""
diffpool specific data partition, pre-process and shuffling
"""
if prog_args.data_mode != "default":
print("overwrite node attributes with DiffPool's preprocess setting")
if prog_args.data_mode == 'id':
for g, _ in dataset:
id_list = np.arange(g.number_of_nodes())
g.ndata['feat'] = one_hotify(id_list, pad=dataset.max_num_node)
elif prog_args.data_mode == 'deg-num':
for g, _ in dataset:
g.ndata['feat'] = np.expand_dims(g.in_degrees(), axis=1)
elif prog_args.data_mode == 'deg':
for g in dataset:
degs = list(g.in_degrees())
degs_one_hot = one_hotify(degs, pad=dataset.max_degrees)
g.ndata['feat'] = degs_one_hot
from .gnn import GraphSage, GraphSageLayer, DiffPoolBatchedGraphLayer
import torch
import torch.nn as nn
import torch.nn.functional as F
class Aggregator(nn.Module):
"""
Base Aggregator class. Adapting
from PR# 403
This class is not supposed to be called
"""
def __init__(self):
super(Aggregator, self).__init__()
def forward(self, node):
neighbour = node.mailbox['m']
c = self.aggre(neighbour)
return {"c": c}
def aggre(self, neighbour):
# N x F
raise NotImplementedError
class MeanAggregator(Aggregator):
'''
Mean Aggregator for graphsage
'''
def __init__(self):
super(MeanAggregator, self).__init__()
def aggre(self, neighbour):
mean_neighbour = torch.mean(neighbour, dim=1)
return mean_neighbour
class MaxPoolAggregator(Aggregator):
'''
Maxpooling aggregator for graphsage
'''
def __init__(self, in_feats, out_feats, activation, bias):
super(MaxPoolAggregator, self).__init__()
self.linear = nn.Linear(in_feats, out_feats, bias=bias)
self.activation = activation
#Xavier initialization of weight
nn.init.xavier_uniform_(self.linear.weight,
gain=nn.init.calculate_gain('relu'))
def aggre(self, neighbour):
neighbour = self.linear(neighbour)
if self.activation:
neighbour = self.activation(neighbour)
maxpool_neighbour = torch.max(neighbour, dim=1)[0]
return maxpool_neighbour
class LSTMAggregator(Aggregator):
'''
LSTM aggregator for graphsage
'''
def __init__(self, in_feats, hidden_feats):
super(LSTMAggregator, self).__init__()
self.lstm = nn.LSTM(in_feats, hidden_feats, batch_first=True)
self.hidden_dim = hidden_feats
self.hidden = self.init_hidden()
        nn.init.xavier_uniform_(self.lstm.weight_ih_l0,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.lstm.weight_hh_l0,
                                gain=nn.init.calculate_gain('relu'))
def init_hidden(self):
"""
        Default: initialize the hidden state to all zeros.
"""
return (torch.zeros(1, 1, self.hidden_dim),
torch.zeros(1, 1, self.hidden_dim))
def aggre(self, neighbours):
'''
aggregation function
'''
# N X F
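        # the LSTM output depends on input order, but neighbours have no
        # canonical order; shuffle them so no fixed ordering is learned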
rand_order = torch.randperm(neighbours.size()[1])
neighbours = neighbours[:, rand_order, :]
(lstm_out, self.hidden) = self.lstm(neighbours.view(neighbours.size()[0],
neighbours.size()[1],
-1))
return lstm_out[:, -1, :]
def forward(self, node):
neighbour = node.mailbox['m']
c = self.aggre(neighbour)
return {"c":c}
import torch
import torch.nn as nn
import torch.nn.functional as F
class Bundler(nn.Module):
"""
Bundler, which will be the node_apply function in DGL paradigm
"""
def __init__(self, in_feats, out_feats, activation, dropout, bias=True):
super(Bundler, self).__init__()
self.dropout = nn.Dropout(p=dropout)
self.linear = nn.Linear(in_feats*2, out_feats, bias)
self.activation = activation
nn.init.xavier_uniform_(self.linear.weight,
gain=nn.init.calculate_gain('relu'))
def concat(self, h, aggre_result):
bundle = torch.cat((h, aggre_result),1)
bundle = self.linear(bundle)
return bundle
def forward(self, node):
h = node.data['h']
c = node.data['c']
bundle = self.concat(h, c)
bundle = F.normalize(bundle, p=2, dim=1)
if self.activation:
bundle = self.activation(bundle)
return {"h":bundle}
import torch
import torch.nn as nn
import numpy as np
from scipy.linalg import block_diag
import dgl.function as fn
from .aggregator import MaxPoolAggregator, MeanAggregator, LSTMAggregator
from .bundler import Bundler
from ..model_utils import masked_softmax
from model.loss import EntropyLoss
class GraphSageLayer(nn.Module):
"""
GraphSage layer in Inductive learning paper by hamilton
Here, graphsage layer is a reduced function in DGL framework
"""
def __init__(self, in_feats, out_feats, activation, dropout,
aggregator_type, bn=False, bias=True):
super(GraphSageLayer, self).__init__()
self.use_bn = bn
self.bundler = Bundler(in_feats, out_feats, activation, dropout,
bias=bias)
self.dropout = nn.Dropout(p=dropout)
if aggregator_type == "maxpool":
self.aggregator = MaxPoolAggregator(in_feats, in_feats,
activation, bias)
elif aggregator_type == "lstm":
self.aggregator = LSTMAggregator(in_feats, in_feats)
else:
self.aggregator = MeanAggregator()
def forward(self, g, h):
h = self.dropout(h)
g.ndata['h'] = h
if self.use_bn and not hasattr(self, 'bn'):
device = h.device
self.bn = nn.BatchNorm1d(h.size()[1]).to(device)
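        # message passing: copy each node's 'h' into its neighbours' mailbox 'm',
        # reduce with the chosen aggregator (writes 'c'), then apply the Bundler,
        # which concatenates 'h' with 'c' and applies a linear projection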
g.update_all(fn.copy_src(src='h', out='m'), self.aggregator,
self.bundler)
if self.use_bn:
h = self.bn(h)
h = g.ndata.pop('h')
return h
class GraphSage(nn.Module):
"""
Grahpsage network that concatenate several graphsage layer
"""
def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation,
dropout, aggregator_type):
super(GraphSage, self).__init__()
self.layers = nn.ModuleList()
#input layer
self.layers.append(GraphSageLayer(in_feats, n_hidden, activation, dropout,
aggregator_type))
# hidden layers
for _ in range(n_layers -1):
self.layers.append(GraphSageLayer(n_hidden, n_hidden, activation,
dropout, aggregator_type))
#output layer
self.layers.append(GraphSageLayer(n_hidden, n_classes, None,
dropout, aggregator_type))
def forward(self, g, features):
h = features
for layer in self.layers:
h = layer(g, h)
return h
class DiffPoolBatchedGraphLayer(nn.Module):
def __init__(self, input_dim, assign_dim, output_feat_dim, activation, dropout, aggregator_type, link_pred):
super(DiffPoolBatchedGraphLayer, self).__init__()
self.embedding_dim = input_dim
self.assign_dim = assign_dim
self.hidden_dim = output_feat_dim
self.link_pred = link_pred
self.feat_gc = GraphSageLayer(input_dim, output_feat_dim, activation, dropout, aggregator_type)
self.pool_gc = GraphSageLayer(input_dim, assign_dim, activation, dropout, aggregator_type)
self.reg_loss = nn.ModuleList([])
self.loss_log = {}
self.reg_loss.append(EntropyLoss())
def forward(self, g, h):
feat = self.feat_gc(g, h)
assign_tensor = self.pool_gc(g,h)
device = feat.device
assign_tensor_masks = []
batch_size = len(g.batch_num_nodes)
for g_n_nodes in g.batch_num_nodes:
            mask = torch.ones((g_n_nodes,
                               int(assign_tensor.size()[1] / batch_size)))
assign_tensor_masks.append(mask)
"""
The first pooling layer is computed on batched graph.
We first take the adjacency matrix of the batched graph, which is block-wise diagonal.
We then compute the assignment matrix for the whole batch graph, which will also be block diagonal
"""
mask = torch.FloatTensor(block_diag(*assign_tensor_masks)).to(device=device)
assign_tensor = masked_softmax(assign_tensor, mask,
memory_efficient=False)
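        # pooled node features: X' = S^T Z, where S is the (masked, softmaxed)
        # assignment matrix and Z the node embeddings from feat_gc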
h = torch.matmul(torch.t(assign_tensor),feat)
adj = g.adjacency_matrix(ctx=device)
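        # coarsened adjacency: A' = S^T A S, computed as a sparse-dense matmul
        # followed by a dense matmul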
adj_new = torch.sparse.mm(adj, assign_tensor)
adj_new = torch.mm(torch.t(assign_tensor), adj_new)
if self.link_pred:
current_lp_loss = torch.norm(adj.to_dense() -\
torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(),2)
self.loss_log['LinkPredLoss'] = current_lp_loss
for loss_layer in self.reg_loss:
loss_name = str(type(loss_layer).__name__)
self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)
return adj_new, h
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np
from scipy.linalg import block_diag
import dgl
from .dgl_layers import GraphSage, GraphSageLayer, DiffPoolBatchedGraphLayer
from .tensorized_layers import *
from .model_utils import batch2tensor
import time
class DiffPool(nn.Module):
"""
DiffPool Fuse
"""
def __init__(self, input_dim, hidden_dim, embedding_dim,
label_dim, activation, n_layers, dropout,
n_pooling, linkpred, batch_size, aggregator_type,
assign_dim, pool_ratio, cat=False):
super(DiffPool, self).__init__()
self.link_pred = linkpred
self.concat = cat
self.n_pooling = n_pooling
self.batch_size = batch_size
self.link_pred_loss = []
self.entropy_loss = []
# list of GNN modules before the first diffpool operation
self.gc_before_pool = nn.ModuleList()
self.diffpool_layers = nn.ModuleList()
# list of list of GNN modules, each list after one diffpool operation
self.gc_after_pool = nn.ModuleList()
self.assign_dim = assign_dim
self.bn = True
self.num_aggs = 1
# constructing layers
# layers before diffpool
assert n_layers >= 3, "n_layers too few"
self.gc_before_pool.append(GraphSageLayer(input_dim, hidden_dim, activation, dropout, aggregator_type, self.bn))
for _ in range(n_layers - 2):
self.gc_before_pool.append(GraphSageLayer(hidden_dim, hidden_dim, activation, dropout, aggregator_type, self.bn))
self.gc_before_pool.append(GraphSageLayer(hidden_dim, embedding_dim, None, dropout, aggregator_type))
assign_dims = []
assign_dims.append(self.assign_dim)
if self.concat:
            # DiffPool layers receive a pool_embedding_dim node feature tensor
            # and return pool_embedding_dim node embeddings
pool_embedding_dim = hidden_dim * (n_layers -1) + embedding_dim
else:
pool_embedding_dim = embedding_dim
self.first_diffpool_layer = DiffPoolBatchedGraphLayer(pool_embedding_dim, self.assign_dim, hidden_dim,activation, dropout, aggregator_type, self.link_pred)
gc_after_per_pool = nn.ModuleList()
for _ in range(n_layers - 1):
gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))
gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, embedding_dim))
self.gc_after_pool.append(gc_after_per_pool)
self.assign_dim = int(self.assign_dim * pool_ratio)
# each pooling module
for _ in range(n_pooling-1):
self.diffpool_layers.append(BatchedDiffPool(pool_embedding_dim, self.assign_dim, hidden_dim, self.link_pred))
gc_after_per_pool = nn.ModuleList()
for _ in range(n_layers - 1):
gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))
gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, embedding_dim))
self.gc_after_pool.append(gc_after_per_pool)
assign_dims.append(self.assign_dim)
self.assign_dim = int(self.assign_dim * pool_ratio)
# predicting layer
if self.concat:
self.pred_input_dim = pool_embedding_dim*self.num_aggs*(n_pooling+1)
else:
self.pred_input_dim = embedding_dim*self.num_aggs
self.pred_layer = nn.Linear(self.pred_input_dim, label_dim)
# weight initialization
for m in self.modules():
if isinstance(m, nn.Linear):
m.weight.data = init.xavier_uniform_(m.weight.data,
gain=nn.init.calculate_gain('relu'))
if m.bias is not None:
m.bias.data = init.constant_(m.bias.data, 0.0)
def gcn_forward(self, g, h, gc_layers, cat=False):
"""
Return gc_layer embedding cat.
"""
block_readout = []
for gc_layer in gc_layers[:-1]:
h = gc_layer(g, h)
block_readout.append(h)
h = gc_layers[-1](g, h)
block_readout.append(h)
if cat:
block = torch.cat(block_readout, dim=1) # N x F, F = F1 + F2 + ...
else:
block = h
return block
def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False):
block_readout = []
for gc_layer in gc_layers:
h = gc_layer(h, adj)
block_readout.append(h)
if cat:
block = torch.cat(block_readout, dim=2) # N x F, F = F1 + F2 + ...
else:
block = h
return block
def forward(self, g):
self.link_pred_loss = []
self.entropy_loss = []
h = g.ndata['feat']
# node feature for assignment matrix computation is the same as the
# original node feature
h_a = h
out_all = []
# we use GCN blocks to get an embedding first
g_embedding = self.gcn_forward(g, h, self.gc_before_pool, self.concat)
g.ndata['h'] = g_embedding
readout = dgl.sum_nodes(g, 'h')
out_all.append(readout)
if self.num_aggs == 2:
readout = dgl.max_nodes(g, 'h')
out_all.append(readout)
adj, h = self.first_diffpool_layer(g, g_embedding)
node_per_pool_graph = int(adj.size()[0] / self.batch_size)
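        # the pooled batched graph is block-diagonal; split it into per-graph
        # dense adjacency (batch, k, k) and feature (batch, k, d) tensors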
h, adj = batch2tensor(adj, h, node_per_pool_graph)
h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[0], self.concat)
readout = torch.sum(h, dim=1)
out_all.append(readout)
if self.num_aggs == 2:
readout, _ = torch.max(h, dim=1)
out_all.append(readout)
for i, diffpool_layer in enumerate(self.diffpool_layers):
h, adj = diffpool_layer(h, adj)
h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[i+1], self.concat)
readout = torch.sum(h, dim=1)
out_all.append(readout)
if self.num_aggs == 2:
readout, _ = torch.max(h, dim=1)
out_all.append(readout)
if self.concat or self.num_aggs > 1:
final_readout = torch.cat(out_all, dim=1)
else:
final_readout = readout
ypred = self.pred_layer(final_readout)
return ypred
def loss(self, pred, label):
'''
loss function
'''
#softmax + CE
criterion = nn.CrossEntropyLoss()
loss = criterion(pred, label)
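        # add the auxiliary losses (entropy / link prediction) recorded by the
        # tensorized pooling layers during the forward pass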
for diffpool_layer in self.diffpool_layers:
for key, value in diffpool_layer.loss_log.items():
loss += value
return loss
import torch
import torch.nn as nn
class EntropyLoss(nn.Module):
# Return Scalar
def forward(self, adj, anext, s_l):
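        # entropy of each node's cluster-assignment distribution, summed over
        # nodes and averaged over the batch; minimizing it pushes assignments
        # towards one-hot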
entropy = (torch.distributions.Categorical(probs=s_l).entropy()).sum(-1).mean(-1)
assert not torch.isnan(entropy)
return entropy
class LinkPredLoss(nn.Module):
def forward(self, adj, anext, s_l):
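        # Frobenius norm of (A - S S^T), normalized by the number of entries;
        # encourages nodes connected in A to share cluster assignments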
link_pred_loss = (adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
return link_pred_loss.mean()
import torch as th
from torch.autograd import Function
def batch2tensor(batch_adj, batch_feat, node_per_pool_graph):
"""
transform a batched graph to batched adjacency tensor and node feature tensor
"""
batch_size = int(batch_adj.size()[0] / node_per_pool_graph)
adj_list = []
feat_list = []
for i in range(batch_size):
start = i*node_per_pool_graph
end = (i+1)*node_per_pool_graph
adj_list.append(batch_adj[start:end,start:end])
feat_list.append(batch_feat[start:end,:])
adj_list = list(map(lambda x : th.unsqueeze(x, 0), adj_list))
feat_list = list(map(lambda x : th.unsqueeze(x, 0), feat_list))
adj = th.cat(adj_list,dim=0)
feat = th.cat(feat_list, dim=0)
return feat, adj
def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
mask_fill_value=-1e32):
'''
masked_softmax for dgl batch graph
code snippet contributed by AllenNLP (https://github.com/allenai/allennlp)
'''
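    # two masking strategies: the non-memory-efficient path zeroes masked
    # entries before the softmax and renormalizes; the memory-efficient path
    # fills masked entries with a large negative value before the softmax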
if mask is None:
result = th.nn.functional.softmax(matrix, dim=dim)
else:
mask = mask.float()
while mask.dim() < matrix.dim():
mask = mask.unsqueeze(1)
if not memory_efficient:
result = th.nn.functional.softmax(matrix * mask, dim=dim)
result = result * mask
result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
else:
masked_matrix = matrix.masked_fill((1-mask).byte(),
mask_fill_value)
result = th.nn.functional.softmax(masked_matrix, dim=dim)
return result
from .diffpool import BatchedDiffPool
from .graphsage import BatchedGraphSAGE
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from model.tensorized_layers.graphsage import BatchedGraphSAGE
class DiffPoolAssignment(nn.Module):
def __init__(self, nfeat, nnext):
super().__init__()
self.assign_mat = BatchedGraphSAGE(nfeat, nnext, use_bn=True)
def forward(self, x, adj, log=False):
s_l_init = self.assign_mat(x, adj)
s_l = F.softmax(s_l_init, dim=-1)
return s_l
import torch
from torch import nn as nn
from model.tensorized_layers.assignment import DiffPoolAssignment
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from model.loss import EntropyLoss, LinkPredLoss
class BatchedDiffPool(nn.Module):
def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True):
super(BatchedDiffPool, self).__init__()
self.link_pred = link_pred
self.log = {}
self.link_pred_layer = LinkPredLoss()
self.embed = BatchedGraphSAGE(nfeat, nhid, use_bn=True)
self.assign = DiffPoolAssignment(nfeat, nnext)
self.reg_loss = nn.ModuleList([])
self.loss_log = {}
if link_pred:
self.reg_loss.append(LinkPredLoss())
if entropy:
self.reg_loss.append(EntropyLoss())
def forward(self, x, adj, log=False):
z_l = self.embed(x, adj)
s_l = self.assign(x, adj)
if log:
self.log['s'] = s_l.cpu().numpy()
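        # DiffPool coarsening: X' = S^T Z and A' = S^T A S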
xnext = torch.matmul(s_l.transpose(-1, -2), z_l)
anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l)
for loss_layer in self.reg_loss:
loss_name = str(type(loss_layer).__name__)
self.loss_log[loss_name] = loss_layer(adj, anext, s_l)
if log:
self.log['a'] = anext.cpu().numpy()
return xnext, anext
import torch
from torch import nn as nn
from torch.nn import functional as F
class BatchedGraphSAGE(nn.Module):
def __init__(self, infeat, outfeat, use_bn=True, mean=False, add_self=False):
super().__init__()
self.add_self = add_self
self.use_bn = use_bn
self.mean = mean
self.W = nn.Linear(infeat, outfeat, bias=True)
nn.init.xavier_uniform_(self.W.weight, gain=nn.init.calculate_gain('relu'))
def forward(self, x, adj):
if self.use_bn and not hasattr(self, 'bn'):
self.bn = nn.BatchNorm1d(adj.size(1)).to(adj.device)
if self.add_self:
adj = adj + torch.eye(adj.size(0)).to(adj.device)
if self.mean:
adj = adj / adj.sum(1, keepdim=True)
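        # dense GraphSAGE aggregation: sum (or mean) neighbour features via a
        # batched matmul with the adjacency, then apply the linear projection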
h_k_N = torch.matmul(adj, x)
h_k = self.W(h_k_N)
h_k = F.normalize(h_k, dim=2, p=2)
h_k = F.relu(h_k)
if self.use_bn:
h_k = self.bn(h_k)
return h_k
def __repr__(self):
if self.use_bn:
return 'BN' + super(BatchedGraphSAGE, self).__repr__()
else:
return super(BatchedGraphSAGE, self).__repr__()
import os
import numpy as np
import torch
import dgl
import networkx as nx
import argparse
import random
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu
from model.encoder import DiffPool
from data_utils import pre_process
def arg_parse():
'''
argument parser
'''
parser = argparse.ArgumentParser(description='DiffPool arguments')
parser.add_argument('--dataset', dest='dataset', help='Input Dataset')
parser.add_argument('--pool_ratio', dest='pool_ratio', type=float, help='pooling ratio')
parser.add_argument('--num_pool', dest='num_pool', type=int, help='num_pooling layer')
parser.add_argument('--no_link_pred', dest='linkpred', action='store_false',
help='switch of link prediction object')
parser.add_argument('--cuda', dest='cuda', type=int, help='switch cuda')
parser.add_argument('--lr', dest='lr', type=float, help='learning rate')
parser.add_argument('--clip', dest='clip', type=float, help='gradient clipping')
parser.add_argument('--batch-size', dest='batch_size', type=int, help='batch size')
parser.add_argument('--epochs', dest='epoch', type=int,
help='num-of-epoch')
parser.add_argument('--train-ratio', dest='train_ratio', type=float,
                        help='ratio of training dataset split')
parser.add_argument('--test-ratio', dest='test_ratio', type=float,
help='ratio of testing dataset split')
parser.add_argument('--num_workers', dest='n_worker', type=int,
                        help='number of workers for data loading')
parser.add_argument('--gc-per-block', dest='gc_per_block', type=int,
help='number of graph conv layer per block')
parser.add_argument('--bn', dest='bn', action='store_const', const=True,
default=True, help='switch for bn')
parser.add_argument('--dropout', dest='dropout', type=float,
help='dropout rate')
parser.add_argument('--bias', dest='bias', action='store_const',
const=True, default=True, help='switch for bias')
parser.add_argument('--save_dir', dest='save_dir', help='model saving directory: SAVE_DICT/DATASET')
parser.add_argument('--load_epoch', dest='load_epoch', help='load trained model params from\
SAVE_DICT/DATASET/model-LOAD_EPOCH')
    parser.add_argument('--data_mode', dest='data_mode',
                        choices=['default', 'id', 'deg', 'deg-num'],
                        help='node feature pre-processing mode: default, '
                             'id (one-hot node id), deg-num (node degree as '
                             'a scalar feature), or deg (one-hot node degree)')
parser.set_defaults(dataset='ENZYMES',
pool_ratio=0.15,
num_pool=1,
cuda=1,
lr=1e-3,
clip=2.0,
batch_size=20,
epoch=4000,
train_ratio=0.7,
test_ratio=0.1,
n_worker=1,
gc_per_block=3,
dropout=0.0,
method='diffpool',
bn=True,
bias=True,
save_dir="./model_param",
load_epoch=-1,
data_mode = 'default')
return parser.parse_args()
def prepare_data(dataset, prog_args, train=False, pre_process=None):
'''
preprocess TU dataset according to DiffPool's paper setting and load dataset into dataloader
'''
if train:
shuffle = True
else:
shuffle = False
if pre_process:
pre_process(dataset, prog_args)
#dataset.set_fold(fold)
return torch.utils.data.DataLoader(dataset,
batch_size=prog_args.batch_size,
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=True,
num_workers=prog_args.n_worker)
def graph_classify_task(prog_args):
'''
perform graph classification task
'''
dataset = tu.TUDataset(name=prog_args.dataset)
train_size = int(prog_args.train_ratio * len(dataset))
test_size = int(prog_args.test_ratio * len(dataset))
val_size = int(len(dataset) - train_size - test_size)
dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(dataset, (train_size, val_size, test_size))
train_dataloader = prepare_data(dataset_train, prog_args, train=True,
pre_process=pre_process)
val_dataloader = prepare_data(dataset_val, prog_args, train=False,
pre_process=pre_process)
test_dataloader = prepare_data(dataset_test, prog_args, train=False,
pre_process=pre_process)
input_dim, label_dim, max_num_node = dataset.statistics()
print("++++++++++STATISTICS ABOUT THE DATASET")
print("dataset feature dimension is", input_dim)
print("dataset label dimension is", label_dim)
print("the max num node is", max_num_node)
print("number of graphs is", len(dataset))
# assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size"
    hidden_dim = 64
embedding_dim = 64
# calculate assignment dimension: pool_ratio * largest graph's maximum
# number of nodes in the dataset
assign_dim = int(max_num_node * prog_args.pool_ratio) * prog_args.batch_size
print("++++++++++MODEL STATISTICS++++++++")
print("model hidden dim is", hidden_dim)
print("model embedding dim for graph instance embedding", embedding_dim)
print("initial batched pool graph dim is", assign_dim)
activation = F.relu
# initialize model
# 'diffpool' : diffpool
model = DiffPool(input_dim,
hidden_dim,
embedding_dim,
label_dim,
activation,
prog_args.gc_per_block,
prog_args.dropout,
prog_args.num_pool,
prog_args.linkpred,
prog_args.batch_size,
'meanpool',
assign_dim,
prog_args.pool_ratio)
if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset\
+ "/model.iter-" + str(prog_args.load_epoch)))
print("model init finished")
print("MODEL:::::::", prog_args.method)
if prog_args.cuda:
model = model.cuda()
logger = train(train_dataloader, model, prog_args, val_dataset=val_dataloader)
result = evaluate(test_dataloader, model, prog_args, logger)
print("test accuracy {}%".format(result*100))
def collate_fn(batch):
    '''
    collate_fn for dataset batching:
    cast node data to float tensors and batch the graphs
    '''
graphs, labels = map(list, zip(*batch))
#cuda = torch.cuda.is_available()
# batch graphs and cast to PyTorch tensor
for graph in graphs:
for (key, value) in graph.ndata.items():
graph.ndata[key] = torch.FloatTensor(value)
batched_graphs = dgl.batch(graphs)
# cast to PyTorch tensor
batched_labels = torch.LongTensor(np.array(labels))
return batched_graphs, batched_labels
def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
'''
training function
'''
dir = prog_args.save_dir + "/" + prog_args.dataset
if not os.path.exists(dir):
os.makedirs(dir)
dataloader = dataset
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
model.parameters()), lr=0.001)
early_stopping_logger = {"best_epoch":-1, "val_acc": -1}
if prog_args.cuda > 0:
torch.cuda.set_device(0)
for epoch in range(prog_args.epoch):
begin_time = time.time()
model.train()
accum_correct = 0
total = 0
print("EPOCH ###### {} ######".format(epoch))
computation_time = 0.0
for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
if torch.cuda.is_available():
for (key, value) in batch_graph.ndata.items():
batch_graph.ndata[key] = value.cuda()
graph_labels = graph_labels.cuda()
model.zero_grad()
compute_start = time.time()
ypred = model(batch_graph)
indi = torch.argmax(ypred, dim=1)
correct = torch.sum(indi == graph_labels).item()
accum_correct += correct
total += graph_labels.size()[0]
loss = model.loss(ypred, graph_labels)
loss.backward()
batch_compute_time = time.time() - compute_start
computation_time += batch_compute_time
nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
optimizer.step()
train_accu = accum_correct / total
print("train accuracy for this epoch {} is {}%".format(epoch,
train_accu*100))
elapsed_time = time.time() - begin_time
print("loss {} with epoch time {} s & computation time {} s ".format(loss.item(), elapsed_time, computation_time))
if val_dataset is not None:
result = evaluate(val_dataset, model, prog_args)
print("validation accuracy {}%".format(result*100))
if result >= early_stopping_logger['val_acc'] and result <=\
train_accu:
early_stopping_logger.update(best_epoch=epoch, val_acc=result)
if prog_args.save_dir is not None:
torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset\
+ "/model.iter-" + str(early_stopping_logger['best_epoch']))
print("best epoch is EPOCH {}, val_acc is {}%".format(early_stopping_logger['best_epoch'],
early_stopping_logger['val_acc']*100))
torch.cuda.empty_cache()
return early_stopping_logger
def evaluate(dataloader, model, prog_args, logger=None):
'''
evaluate function
'''
if logger is not None and prog_args.save_dir is not None:
model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset\
+ "/model.iter-" + str(logger['best_epoch'])))
model.eval()
correct_label = 0
with torch.no_grad():
for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
if torch.cuda.is_available():
for (key, value) in batch_graph.ndata.items():
batch_graph.ndata[key] = value.cuda()
graph_labels = graph_labels.cuda()
ypred = model(batch_graph)
indi = torch.argmax(ypred, dim=1)
correct = torch.sum(indi==graph_labels)
correct_label += correct.item()
result = correct_label / (len(dataloader)*prog_args.batch_size)
return result
def main():
'''
main
'''
prog_args = arg_parse()
print(prog_args)
graph_classify_task(prog_args)
if __name__ == "__main__":
main()
@@ -2,6 +2,7 @@ from __future__ import absolute_import
 import numpy as np
 import dgl
 import os
+import random
 
 from dgl.data.utils import download, extract_archive, get_download_dir
@@ -21,13 +22,15 @@ class TUDataset(object):
     """
 
-    _url = r"https://ls11-www.cs.uni-dortmund.de/people/morris/graphkerneldatasets/{}.zip"
+    _url = r"https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/{}.zip"
 
-    def __init__(self, name, use_pandas=False, hidden_size=10):
+    def __init__(self, name, use_pandas=False, hidden_size=10, max_allow_node=None):
 
         self.name = name
         self.hidden_size = hidden_size
         self.extract_dir = self._download()
+        self.data_mode = None
+        self.max_allow_node = max_allow_node
 
         if use_pandas:
             import pandas as pd
@@ -47,10 +50,15 @@ class TUDataset(object):
         g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])
 
         node_idx_list = []
+        self.max_num_node = 0
         for idx in range(np.max(DS_indicator) + 1):
             node_idx = np.where(DS_indicator == idx)
             node_idx_list.append(node_idx[0])
+            if len(node_idx[0]) > self.max_num_node:
+                self.max_num_node = len(node_idx[0])
         self.graph_lists = g.subgraphs(node_idx_list)
+        self.num_labels = max(DS_graph_labels) + 1
         self.graph_labels = DS_graph_labels
 
         try:
@@ -60,6 +68,7 @@ class TUDataset(object):
             one_hot_node_labels = self._to_onehot(DS_node_labels)
             for idxs, g in zip(node_idx_list, self.graph_lists):
                 g.ndata['feat'] = one_hot_node_labels[idxs, :]
+            self.data_mode = "node_label"
         except IOError:
             print("No Node Label Data")
@@ -67,13 +76,30 @@ class TUDataset(object):
             DS_node_attr = np.loadtxt(self._file_path("node_attributes"), delimiter=",")
             for idxs, g in zip(node_idx_list, self.graph_lists):
                 g.ndata['feat'] = DS_node_attr[idxs, :]
+            self.data_mode = "node_attr"
         except IOError:
             print("No Node Attribute Data")
 
         if 'feat' not in g.ndata.keys():
             for idxs, g in zip(node_idx_list, self.graph_lists):
                 g.ndata['feat'] = np.ones((g.number_of_nodes(), hidden_size))
+            self.data_mode = "constant"
             print("Use Constant one as Feature with hidden size {}".format(hidden_size))
 
+        # remove graphs that are too large by a user-given threshold;
+        # an optional pre-processing step in conformity with Rex Ying's original
+        # DiffPool implementation
+        if self.max_allow_node:
+            preserve_idx = []
+            print("original dataset length : ", len(self.graph_lists))
+            for (i, g) in enumerate(self.graph_lists):
+                if g.number_of_nodes() <= self.max_allow_node:
+                    preserve_idx.append(i)
+            self.graph_lists = [self.graph_lists[i] for i in preserve_idx]
+            print("after pruning graphs that are too big : ", len(self.graph_lists))
+            self.graph_labels = [self.graph_labels[i] for i in preserve_idx]
+            self.max_num_node = self.max_allow_node
+
     def __getitem__(self, idx):
         """Get the i^th sample.
@@ -117,5 +143,7 @@ class TUDataset(object):
         return one_hot_tensor
 
     def statistics(self):
-        return self.graph_lists[0].ndata['feat'].shape[1]
+        return self.graph_lists[0].ndata['feat'].shape[1],\
+            self.num_labels,\
+            self.max_num_node