Commit bf3994ee authored by HQ, committed by VoVAllen

[Model][Hotfix] DiffPool formatting fix (#696)

* formatting

* formatting
parent 90e78c58
import numpy as np
import torch
def one_hotify(labels, pad=-1):
    '''
    cast label to one hot vector
    '''
    num_instances = len(labels)
    if pad <= 0:
        dim_embedding = np.max(labels) + 1  # zero-indexed assumed
    else:
        assert pad > 0, "result_dim for padding one hot embedding not set!"
        dim_embedding = pad + 1
    embeddings = np.zeros((num_instances, dim_embedding))
    embeddings[np.arange(num_instances), labels] = 1
    return embeddings
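As a quick sanity check, here is a minimal standalone sketch (toy labels, illustrative only, not part of the committed file) of what `one_hotify` produces with and without padding:

```python
import numpy as np

def one_hotify(labels, pad=-1):
    # mirrors the helper above so this sketch runs on its own
    num_instances = len(labels)
    dim_embedding = np.max(labels) + 1 if pad <= 0 else pad + 1
    embeddings = np.zeros((num_instances, dim_embedding))
    embeddings[np.arange(num_instances), labels] = 1
    return embeddings

labels = [0, 2, 1]
print(one_hotify(labels))         # 3 x 3: one column per observed class
print(one_hotify(labels, pad=4))  # 3 x 5: columns 3 and 4 stay as zero padding
```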
def pre_process(dataset, prog_args):
"""
diffpool specific data partition, pre-process and shuffling
"""
if prog_args.data_mode != "default":
print("overwrite node attributes with DiffPool's preprocess setting")
if prog_args.data_mode == 'id':
for g, _ in dataset:
id_list = np.arange(g.number_of_nodes())
g.ndata['feat'] = one_hotify(id_list, pad=dataset.max_num_node)
elif prog_args.data_mode == 'deg-num':
for g, _ in dataset:
g.ndata['feat'] = np.expand_dims(g.in_degrees(), axis=1)
elif prog_args.data_mode == 'deg':
for g in dataset:
degs = list(g.in_degrees())
degs_one_hot = one_hotify(degs, pad=dataset.max_degrees)
g.ndata['feat'] = degs_one_hot
\ No newline at end of file
"""
diffpool specific data partition, pre-process and shuffling
"""
if prog_args.data_mode != "default":
print("overwrite node attributes with DiffPool's preprocess setting")
if prog_args.data_mode == 'id':
for g, _ in dataset:
id_list = np.arange(g.number_of_nodes())
g.ndata['feat'] = one_hotify(id_list, pad=dataset.max_num_node)
elif prog_args.data_mode == 'deg-num':
for g, _ in dataset:
g.ndata['feat'] = np.expand_dims(g.in_degrees(), axis=1)
elif prog_args.data_mode == 'deg':
for g in dataset:
degs = list(g.in_degrees())
degs_one_hot = one_hotify(degs, pad=dataset.max_degrees)
g.ndata['feat'] = degs_one_hot
......@@ -10,6 +10,7 @@ class Aggregator(nn.Module):
This class is not supposed to be called
"""
def __init__(self):
super(Aggregator, self).__init__()
......@@ -22,10 +23,12 @@ class Aggregator(nn.Module):
# N x F
raise NotImplementedError
class MeanAggregator(Aggregator):
'''
Mean Aggregator for graphsage
'''
def __init__(self):
super(MeanAggregator, self).__init__()
......@@ -33,15 +36,17 @@ class MeanAggregator(Aggregator):
mean_neighbour = torch.mean(neighbour, dim=1)
return mean_neighbour
class MaxPoolAggregator(Aggregator):
'''
Maxpooling aggregator for graphsage
'''
def __init__(self, in_feats, out_feats, activation, bias):
super(MaxPoolAggregator, self).__init__()
self.linear = nn.Linear(in_feats, out_feats, bias=bias)
self.activation = activation
        # Xavier initialization of weight
nn.init.xavier_uniform_(self.linear.weight,
gain=nn.init.calculate_gain('relu'))
......@@ -52,10 +57,12 @@ class MaxPoolAggregator(Aggregator):
maxpool_neighbour = torch.max(neighbour, dim=1)[0]
return maxpool_neighbour
class LSTMAggregator(Aggregator):
'''
LSTM aggregator for graphsage
'''
def __init__(self, in_feats, hidden_feats):
super(LSTMAggregator, self).__init__()
self.lstm = nn.LSTM(in_feats, hidden_feats, batch_first=True)
......@@ -65,7 +72,6 @@ class LSTMAggregator(Aggregator):
nn.init.xavier_uniform_(self.lstm.weight,
gain=nn.init.calculate_gain('relu'))
def init_hidden(self):
"""
        Defaults to initializing all zeros
......@@ -82,11 +88,12 @@ class LSTMAggregator(Aggregator):
neighbours = neighbours[:, rand_order, :]
(lstm_out, self.hidden) = self.lstm(neighbours.view(neighbours.size()[0],
                                                             neighbours.size()[1],
                                                             -1))
return lstm_out[:, -1, :]
def forward(self, node):
neighbour = node.mailbox['m']
c = self.aggre(neighbour)
return {"c":c}
return {"c": c}
......@@ -3,22 +3,22 @@ import torch.nn as nn
import torch.nn.functional as F
class Bundler(nn.Module):
"""
Bundler, which will be the node_apply function in DGL paradigm
"""
def __init__(self, in_feats, out_feats, activation, dropout, bias=True):
super(Bundler, self).__init__()
self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(in_feats * 2, out_feats, bias)
self.activation = activation
nn.init.xavier_uniform_(self.linear.weight,
gain=nn.init.calculate_gain('relu'))
def concat(self, h, aggre_result):
        bundle = torch.cat((h, aggre_result), 1)
bundle = self.linear(bundle)
return bundle
......@@ -29,4 +29,4 @@ class Bundler(nn.Module):
bundle = F.normalize(bundle, p=2, dim=1)
if self.activation:
bundle = self.activation(bundle)
return {"h":bundle}
return {"h": bundle}
......@@ -16,6 +16,7 @@ class GraphSageLayer(nn.Module):
    GraphSAGE layer from the inductive representation learning paper by Hamilton et al.
    Here, the GraphSAGE layer is a reduce function in the DGL framework
"""
def __init__(self, in_feats, out_feats, activation, dropout,
aggregator_type, bn=False, bias=True):
super(GraphSageLayer, self).__init__()
......@@ -50,19 +51,20 @@ class GraphSage(nn.Module):
"""
    GraphSAGE network that stacks several GraphSAGE layers
"""
def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation,
dropout, aggregator_type):
super(GraphSage, self).__init__()
self.layers = nn.ModuleList()
        # input layer
self.layers.append(GraphSageLayer(in_feats, n_hidden, activation, dropout,
aggregator_type))
# hidden layers
        for _ in range(n_layers - 1):
self.layers.append(GraphSageLayer(n_hidden, n_hidden, activation,
dropout, aggregator_type))
        # output layer
self.layers.append(GraphSageLayer(n_hidden, n_classes, None,
dropout, aggregator_type))
......@@ -72,50 +74,66 @@ class GraphSage(nn.Module):
h = layer(g, h)
return h
class DiffPoolBatchedGraphLayer(nn.Module):
    def __init__(self, input_dim, assign_dim, output_feat_dim,
                 activation, dropout, aggregator_type, link_pred):
super(DiffPoolBatchedGraphLayer, self).__init__()
self.embedding_dim = input_dim
self.assign_dim = assign_dim
self.hidden_dim = output_feat_dim
self.link_pred = link_pred
        self.feat_gc = GraphSageLayer(
            input_dim,
            output_feat_dim,
            activation,
            dropout,
            aggregator_type)
        self.pool_gc = GraphSageLayer(
            input_dim,
            assign_dim,
            activation,
            dropout,
            aggregator_type)
self.reg_loss = nn.ModuleList([])
self.loss_log = {}
self.reg_loss.append(EntropyLoss())
def forward(self, g, h):
feat = self.feat_gc(g, h)
        assign_tensor = self.pool_gc(g, h)
device = feat.device
assign_tensor_masks = []
batch_size = len(g.batch_num_nodes)
for g_n_nodes in g.batch_num_nodes:
            mask = torch.ones((g_n_nodes,
                               int(assign_tensor.size()[1] / batch_size)))
assign_tensor_masks.append(mask)
"""
The first pooling layer is computed on batched graph.
The first pooling layer is computed on batched graph.
We first take the adjacency matrix of the batched graph, which is block-wise diagonal.
We then compute the assignment matrix for the whole batch graph, which will also be block diagonal
"""
        mask = torch.FloatTensor(
            block_diag(*assign_tensor_masks)).to(device=device)
assign_tensor = masked_softmax(assign_tensor, mask,
                                       memory_efficient=False)
        h = torch.matmul(torch.t(assign_tensor), feat)
adj = g.adjacency_matrix(ctx=device)
adj_new = torch.sparse.mm(adj, assign_tensor)
adj_new = torch.mm(torch.t(assign_tensor), adj_new)
if self.link_pred:
            current_lp_loss = torch.norm(adj.to_dense() -
                                         torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2)
self.loss_log['LinkPredLoss'] = current_lp_loss
for loss_layer in self.reg_loss:
loss_name = str(type(loss_layer).__name__)
self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)
        return adj_new, h
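For readers new to DiffPool, a minimal dense sketch of the algebra in `forward` above, on toy shapes and random values outside the DGL code path: the soft assignment `S` maps the `N` input nodes to `K` clusters, the pooled features are `S^T H`, the coarsened adjacency is `S^T A S`, and the optional link-prediction penalty compares `S S^T` with `A`.

```python
import torch

N, K, F = 6, 2, 4                            # nodes, clusters, feature size
A = (torch.rand(N, N) > 0.7).float()         # toy adjacency matrix
H = torch.rand(N, F)                         # toy node features
S = torch.softmax(torch.rand(N, K), dim=1)   # soft cluster assignment

h_pooled = torch.matmul(S.t(), H)                    # K x F, plays the role of h above
a_pooled = torch.matmul(S.t(), torch.matmul(A, S))   # K x K, plays the role of adj_new above
lp_loss = torch.norm(A - S @ S.t()) / (N ** 2)       # same form as the link_pred term
print(h_pooled.shape, a_pooled.shape, lp_loss.item())
```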
......@@ -13,14 +13,16 @@ from .tensorized_layers import *
from .model_utils import batch2tensor
import time
class DiffPool(nn.Module):
"""
DiffPool Fuse
"""
    def __init__(self, input_dim, hidden_dim, embedding_dim,
                 label_dim, activation, n_layers, dropout,
                 n_pooling, linkpred, batch_size, aggregator_type,
                 assign_dim, pool_ratio, cat=False):
super(DiffPool, self).__init__()
self.link_pred = linkpred
self.concat = cat
......@@ -42,55 +44,93 @@ class DiffPool(nn.Module):
# constructing layers
# layers before diffpool
assert n_layers >= 3, "n_layers too few"
        self.gc_before_pool.append(
            GraphSageLayer(
                input_dim,
                hidden_dim,
                activation,
                dropout,
                aggregator_type,
                self.bn))
for _ in range(n_layers - 2):
            self.gc_before_pool.append(
                GraphSageLayer(
                    hidden_dim,
                    hidden_dim,
                    activation,
                    dropout,
                    aggregator_type,
                    self.bn))
        self.gc_before_pool.append(
            GraphSageLayer(
                hidden_dim,
                embedding_dim,
                None,
                dropout,
                aggregator_type))
assign_dims = []
assign_dims.append(self.assign_dim)
if self.concat:
            # the diffpool layer receives a pool_embedding_dim node feature tensor
            # and returns pool_embedding_dim node embeddings
            pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim
else:
pool_embedding_dim = embedding_dim
        self.first_diffpool_layer = DiffPoolBatchedGraphLayer(
            pool_embedding_dim,
            self.assign_dim,
            hidden_dim,
            activation,
            dropout,
            aggregator_type,
            self.link_pred)
gc_after_per_pool = nn.ModuleList()
for _ in range(n_layers - 1):
gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))
gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, embedding_dim))
self.gc_after_pool.append(gc_after_per_pool)
self.assign_dim = int(self.assign_dim * pool_ratio)
# each pooling module
        for _ in range(n_pooling - 1):
            self.diffpool_layers.append(
                BatchedDiffPool(
                    pool_embedding_dim,
                    self.assign_dim,
                    hidden_dim,
                    self.link_pred))
gc_after_per_pool = nn.ModuleList()
for _ in range(n_layers - 1):
                gc_after_per_pool.append(
                    BatchedGraphSAGE(
                        hidden_dim, hidden_dim))
            gc_after_per_pool.append(
                BatchedGraphSAGE(
                    hidden_dim, embedding_dim))
self.gc_after_pool.append(gc_after_per_pool)
assign_dims.append(self.assign_dim)
self.assign_dim = int(self.assign_dim * pool_ratio)
# predicting layer
if self.concat:
            self.pred_input_dim = pool_embedding_dim * \
                self.num_aggs * (n_pooling + 1)
else:
            self.pred_input_dim = embedding_dim * self.num_aggs
self.pred_layer = nn.Linear(self.pred_input_dim, label_dim)
# weight initialization
for m in self.modules():
if isinstance(m, nn.Linear):
m.weight.data = init.xavier_uniform_(m.weight.data,
                                                     gain=nn.init.calculate_gain('relu'))
if m.bias is not None:
m.bias.data = init.constant_(m.bias.data, 0.0)
def gcn_forward(self, g, h, gc_layers, cat=False):
"""
Return gc_layer embedding cat.
......@@ -102,22 +142,22 @@ class DiffPool(nn.Module):
h = gc_layers[-1](g, h)
block_readout.append(h)
if cat:
            block = torch.cat(block_readout, dim=1)  # N x F, F = F1 + F2 + ...
else:
block = h
return block
def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False):
block_readout = []
for gc_layer in gc_layers:
h = gc_layer(h, adj)
block_readout.append(h)
if cat:
            block = torch.cat(block_readout, dim=2)  # N x F, F = F1 + F2 + ...
else:
block = h
return block
def forward(self, g):
self.link_pred_loss = []
self.entropy_loss = []
......@@ -138,22 +178,23 @@ class DiffPool(nn.Module):
if self.num_aggs == 2:
readout = dgl.max_nodes(g, 'h')
out_all.append(readout)
adj, h = self.first_diffpool_layer(g, g_embedding)
node_per_pool_graph = int(adj.size()[0] / self.batch_size)
h, adj = batch2tensor(adj, h, node_per_pool_graph)
        h = self.gcn_forward_tensorized(
            h, adj, self.gc_after_pool[0], self.concat)
readout = torch.sum(h, dim=1)
out_all.append(readout)
if self.num_aggs == 2:
readout, _ = torch.max(h, dim=1)
out_all.append(readout)
for i, diffpool_layer in enumerate(self.diffpool_layers):
h, adj = diffpool_layer(h, adj)
            h = self.gcn_forward_tensorized(
                h, adj, self.gc_after_pool[i + 1], self.concat)
readout = torch.sum(h, dim=1)
out_all.append(readout)
if self.num_aggs == 2:
......@@ -165,7 +206,7 @@ class DiffPool(nn.Module):
final_readout = readout
ypred = self.pred_layer(final_readout)
return ypred
def loss(self, pred, label):
'''
loss function
......@@ -175,5 +216,5 @@ class DiffPool(nn.Module):
loss = criterion(pred, label)
for diffpool_layer in self.diffpool_layers:
for key, value in diffpool_layer.loss_log.items():
                loss += value
return loss
import torch
import torch.nn as nn
class EntropyLoss(nn.Module):
# Return Scalar
def forward(self, adj, anext, s_l):
        entropy = (torch.distributions.Categorical(
            probs=s_l).entropy()).sum(-1).mean(-1)
assert not torch.isnan(entropy)
return entropy
......@@ -12,7 +14,7 @@ class EntropyLoss(nn.Module):
class LinkPredLoss(nn.Module):
def forward(self, adj, anext, s_l):
        link_pred_loss = (
            adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
return link_pred_loss.mean()
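A standalone toy illustration (made-up shapes and random values) of what the two regularizers above measure on a batched assignment tensor `s_l` of shape (batch, nodes, clusters):

```python
import torch

s_l = torch.softmax(torch.rand(2, 5, 3), dim=-1)  # 2 graphs, 5 nodes, 3 clusters
adj = (torch.rand(2, 5, 5) > 0.5).float()         # toy batched adjacency

# EntropyLoss: mean per-graph sum of per-node assignment entropies
# (small when every node is assigned sharply to a single cluster)
entropy = torch.distributions.Categorical(probs=s_l).entropy().sum(-1).mean(-1)

# LinkPredLoss: distance between S S^T and the adjacency, normalized by matrix size
link_pred = (adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
link_pred = (link_pred / (adj.size(1) * adj.size(2))).mean()

print(entropy.item(), link_pred.item())
```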
import torch as th
from torch.autograd import Function
def batch2tensor(batch_adj, batch_feat, node_per_pool_graph):
"""
transform a batched graph to batched adjacency tensor and node feature tensor
......@@ -9,13 +10,13 @@ def batch2tensor(batch_adj, batch_feat, node_per_pool_graph):
adj_list = []
feat_list = []
for i in range(batch_size):
        start = i * node_per_pool_graph
        end = (i + 1) * node_per_pool_graph
        adj_list.append(batch_adj[start:end, start:end])
        feat_list.append(batch_feat[start:end, :])
    adj_list = list(map(lambda x: th.unsqueeze(x, 0), adj_list))
    feat_list = list(map(lambda x: th.unsqueeze(x, 0), feat_list))
    adj = th.cat(adj_list, dim=0)
feat = th.cat(feat_list, dim=0)
return feat, adj
......@@ -38,10 +39,7 @@ def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
result = result * mask
result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
else:
        masked_matrix = matrix.masked_fill((1 - mask).byte(),
mask_fill_value)
result = th.nn.functional.softmax(masked_matrix, dim=dim)
return result
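A small usage sketch of `masked_softmax` as it is called from `DiffPoolBatchedGraphLayer.forward`: the block-diagonal mask restricts each graph's nodes to its own block of assignment columns. Sizes are toy values, and the snippet assumes the `masked_softmax` defined just above is in scope.

```python
import numpy as np
import torch as th
from scipy.linalg import block_diag

# two toy graphs with 2 and 3 nodes, 2 assignment columns per graph
mask = th.FloatTensor(block_diag(np.ones((2, 2)), np.ones((3, 2))))  # 5 x 4 block-diagonal
scores = th.rand(5, 4)

# memory_efficient=False: softmax over the full row, zero out masked entries,
# then renormalize, so each node only distributes probability over its own graph's columns
probs = masked_softmax(scores, mask, memory_efficient=False)  # masked_softmax defined above
print(probs.sum(dim=1))  # every row sums to ~1
```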
......@@ -5,6 +5,7 @@ from model.tensorized_layers.assignment import DiffPoolAssignment
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from model.loss import EntropyLoss, LinkPredLoss
class BatchedDiffPool(nn.Module):
def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True):
super(BatchedDiffPool, self).__init__()
......@@ -20,7 +21,6 @@ class BatchedDiffPool(nn.Module):
if entropy:
self.reg_loss.append(EntropyLoss())
def forward(self, x, adj, log=False):
z_l = self.embed(x, adj)
s_l = self.assign(x, adj)
......@@ -35,5 +35,3 @@ class BatchedDiffPool(nn.Module):
if log:
self.log['a'] = anext.cpu().numpy()
return xnext, anext
......@@ -4,14 +4,17 @@ from torch.nn import functional as F
class BatchedGraphSAGE(nn.Module):
    def __init__(self, infeat, outfeat, use_bn=True,
                 mean=False, add_self=False):
super().__init__()
self.add_self = add_self
self.use_bn = use_bn
self.mean = mean
self.W = nn.Linear(infeat, outfeat, bias=True)
        nn.init.xavier_uniform_(
            self.W.weight,
            gain=nn.init.calculate_gain('relu'))
def forward(self, x, adj):
if self.use_bn and not hasattr(self, 'bn'):
......
......@@ -17,20 +17,37 @@ from dgl.data import tu
from model.encoder import DiffPool
from data_utils import pre_process
def arg_parse():
'''
argument parser
'''
parser = argparse.ArgumentParser(description='DiffPool arguments')
parser.add_argument('--dataset', dest='dataset', help='Input Dataset')
    parser.add_argument(
        '--pool_ratio',
        dest='pool_ratio',
        type=float,
        help='pooling ratio')
    parser.add_argument(
        '--num_pool',
        dest='num_pool',
        type=int,
        help='num_pooling layer')
parser.add_argument('--no_link_pred', dest='linkpred', action='store_false',
help='switch of link prediction object')
parser.add_argument('--cuda', dest='cuda', type=int, help='switch cuda')
parser.add_argument('--lr', dest='lr', type=float, help='learning rate')
    parser.add_argument(
        '--clip',
        dest='clip',
        type=float,
        help='gradient clipping')
    parser.add_argument(
        '--batch-size',
        dest='batch_size',
        type=int,
        help='batch size')
parser.add_argument('--epochs', dest='epoch', type=int,
help='num-of-epoch')
parser.add_argument('--train-ratio', dest='train_ratio', type=float,
......@@ -47,13 +64,16 @@ def arg_parse():
help='dropout rate')
parser.add_argument('--bias', dest='bias', action='store_const',
const=True, default=True, help='switch for bias')
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='model saving directory: SAVE_DICT/DATASET')
parser.add_argument('--load_epoch', dest='load_epoch', help='load trained model params from\
SAVE_DICT/DATASET/model-LOAD_EPOCH')
parser.add_argument('--data_mode', dest='data_mode', help='data\
preprocessing mode: default, id, degree, or one-hot\
vector of degree number', choices=['default', 'id', 'deg',
                                                     'deg_num'])
parser.set_defaults(dataset='ENZYMES',
pool_ratio=0.15,
......@@ -73,9 +93,10 @@ def arg_parse():
bias=True,
save_dir="./model_param",
load_epoch=-1,
                        data_mode='default')
return parser.parse_args()
def prepare_data(dataset, prog_args, train=False, pre_process=None):
'''
preprocess TU dataset according to DiffPool's paper setting and load dataset into dataloader
......@@ -84,11 +105,11 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None):
shuffle = True
else:
shuffle = False
if pre_process:
pre_process(dataset, prog_args)
    # dataset.set_fold(fold)
return torch.utils.data.DataLoader(dataset,
batch_size=prog_args.batch_size,
shuffle=shuffle,
......@@ -101,13 +122,14 @@ def graph_classify_task(prog_args):
'''
perform graph classification task
'''
dataset = tu.TUDataset(name=prog_args.dataset)
train_size = int(prog_args.train_ratio * len(dataset))
test_size = int(prog_args.test_ratio * len(dataset))
    val_size = int(len(dataset) - train_size - test_size)
    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size))
train_dataloader = prepare_data(dataset_train, prog_args, train=True,
pre_process=pre_process)
val_dataloader = prepare_data(dataset_val, prog_args, train=False,
......@@ -122,47 +144,52 @@ def graph_classify_task(prog_args):
print("number of graphs is", len(dataset))
# assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size"
    hidden_dim = 64  # used to be 64
embedding_dim = 64
# calculate assignment dimension: pool_ratio * largest graph's maximum
# number of nodes in the dataset
    assign_dim = int(max_num_node * prog_args.pool_ratio) * \
        prog_args.batch_size
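    # Worked example with illustrative numbers (not taken from any dataset):
    # max_num_node = 120, pool_ratio = 0.15 and batch_size = 20 give each graph
    # int(120 * 0.15) = 18 assignment clusters, so the batched pooled graph
    # starts with 18 * 20 = 360 assignment columns.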
print("++++++++++MODEL STATISTICS++++++++")
print("model hidden dim is", hidden_dim)
print("model embedding dim for graph instance embedding", embedding_dim)
print("initial batched pool graph dim is", assign_dim)
activation = F.relu
# initialize model
# 'diffpool' : diffpool
    model = DiffPool(input_dim,
                     hidden_dim,
                     embedding_dim,
label_dim,
                     activation,
                     prog_args.gc_per_block,
                     prog_args.dropout,
                     prog_args.num_pool,
prog_args.linkpred,
                     prog_args.batch_size,
                     'meanpool',
assign_dim,
prog_args.pool_ratio)
if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
        model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                         + "/model.iter-" + str(prog_args.load_epoch)))
print("model init finished")
print("MODEL:::::::", prog_args.method)
if prog_args.cuda:
model = model.cuda()
    logger = train(
        train_dataloader,
        model,
        prog_args,
        val_dataset=val_dataloader)
result = evaluate(test_dataloader, model, prog_args, logger)
print("test accuracy {}%".format(result*100))
print("test accuracy {}%".format(result * 100))
def collate_fn(batch):
'''
......@@ -183,6 +210,7 @@ def collate_fn(batch):
return batched_graphs, batched_labels
def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
'''
training function
......@@ -193,7 +221,7 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
dataloader = dataset
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
model.parameters()), lr=0.001)
early_stopping_logger = {"best_epoch":-1, "val_acc": -1}
early_stopping_logger = {"best_epoch": -1, "val_acc": -1}
if prog_args.cuda > 0:
torch.cuda.set_device(0)
......@@ -210,7 +238,6 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
batch_graph.ndata[key] = value.cuda()
graph_labels = graph_labels.cuda()
model.zero_grad()
compute_start = time.time()
ypred = model(batch_graph)
......@@ -225,32 +252,33 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
optimizer.step()
train_accu = accum_correct / total
print("train accuracy for this epoch {} is {}%".format(epoch,
                                                               train_accu * 100))
elapsed_time = time.time() - begin_time
print("loss {} with epoch time {} s & computation time {} s ".format(loss.item(), elapsed_time, computation_time))
print("loss {} with epoch time {} s & computation time {} s ".format(
loss.item(), elapsed_time, computation_time))
if val_dataset is not None:
result = evaluate(val_dataset, model, prog_args)
print("validation accuracy {}%".format(result*100))
print("validation accuracy {}%".format(result * 100))
if result >= early_stopping_logger['val_acc'] and result <=\
                    train_accu:
early_stopping_logger.update(best_epoch=epoch, val_acc=result)
if prog_args.save_dir is not None:
                    torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
                               + "/model.iter-" + str(early_stopping_logger['best_epoch']))
                print("best epoch is EPOCH {}, val_acc is {}%".format(early_stopping_logger['best_epoch'],
                                                                      early_stopping_logger['val_acc'] * 100))
torch.cuda.empty_cache()
return early_stopping_logger
def evaluate(dataloader, model, prog_args, logger=None):
'''
evaluate function
'''
if logger is not None and prog_args.save_dir is not None:
        model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
+ "/model.iter-" + str(logger['best_epoch'])))
model.eval()
correct_label = 0
......@@ -262,11 +290,12 @@ def evaluate(dataloader, model, prog_args, logger=None):
graph_labels = graph_labels.cuda()
ypred = model(batch_graph)
indi = torch.argmax(ypred, dim=1)
        correct = torch.sum(indi == graph_labels)
correct_label += correct.item()
    result = correct_label / (len(dataloader) * prog_args.batch_size)
return result
def main():
'''
main
......
......@@ -24,7 +24,8 @@ class TUDataset(object):
_url = r"https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/{}.zip"
    def __init__(self, name, use_pandas=False,
                 hidden_size=10, max_allow_node=None):
self.name = name
self.hidden_size = hidden_size
......@@ -73,7 +74,8 @@ class TUDataset(object):
print("No Node Label Data")
try:
            DS_node_attr = np.loadtxt(
                self._file_path("node_attributes"), delimiter=",")
for idxs, g in zip(node_idx_list, self.graph_lists):
g.ndata['feat'] = DS_node_attr[idxs, :]
self.data_mode = "node_attr"
......@@ -84,8 +86,9 @@ class TUDataset(object):
for idxs, g in zip(node_idx_list, self.graph_lists):
g.ndata['feat'] = np.ones((g.number_of_nodes(), hidden_size))
self.data_mode = "constant"
print("Use Constant one as Feature with hidden size {}".format(hidden_size))
print(
"Use Constant one as Feature with hidden size {}".format(hidden_size))
        # remove graphs that are too large by a user-given standard
        # optional pre-processing step in conformity with Rex Ying's original
# DiffPool implementation
......@@ -96,11 +99,12 @@ class TUDataset(object):
if g.number_of_nodes() <= self.max_allow_node:
preserve_idx.append(i)
self.graph_lists = [self.graph_lists[i] for i in preserve_idx]
print("after pruning graphs that are too big : ", len(self.graph_lists))
print(
"after pruning graphs that are too big : ", len(
self.graph_lists))
self.graph_labels = [self.graph_labels[i] for i in preserve_idx]
self.max_num_node = self.max_allow_node
def __getitem__(self, idx):
"""Get the i^th sample.
        Parameters
......@@ -121,14 +125,18 @@ class TUDataset(object):
def _download(self):
download_dir = get_download_dir()
        zip_file_path = os.path.join(
            download_dir, "tu_{}.zip".format(self.name))
download(self._url.format(self.name), path=zip_file_path)
extract_dir = os.path.join(download_dir, "tu_{}".format(self.name))
extract_archive(zip_file_path, extract_dir)
return extract_dir
def _file_path(self, category):
        return os.path.join(self.extract_dir, self.name,
                            "{}_{}.txt".format(self.name, category))
@staticmethod
def _idx_from_zero(idx_tensor):
......@@ -144,6 +152,5 @@ class TUDataset(object):
def statistics(self):
return self.graph_lists[0].ndata['feat'].shape[1],\
            self.num_labels,\
            self.max_num_node