Commit bf3994ee authored by HQ, committed by VoVAllen

[Model][Hotfix] DiffPool formatting fix (#696)

* formatting

* formatting
parent 90e78c58
import numpy as np
import torch


def one_hotify(labels, pad=-1):
    '''
    cast label to one hot vector
    '''
    num_instances = len(labels)
    if pad <= 0:
        dim_embedding = np.max(labels) + 1  # zero-indexed assumed
    else:
        assert pad > 0, "result_dim for padding one hot embedding not set!"
        dim_embedding = pad + 1
    embeddings = np.zeros((num_instances, dim_embedding))
    embeddings[np.arange(num_instances), labels] = 1
    return embeddings


def pre_process(dataset, prog_args):
    """
    diffpool specific data partition, pre-process and shuffling
    """
    if prog_args.data_mode != "default":
        print("overwrite node attributes with DiffPool's preprocess setting")
        if prog_args.data_mode == 'id':
            for g, _ in dataset:
                id_list = np.arange(g.number_of_nodes())
                g.ndata['feat'] = one_hotify(id_list, pad=dataset.max_num_node)

        elif prog_args.data_mode == 'deg-num':
            for g, _ in dataset:
                g.ndata['feat'] = np.expand_dims(g.in_degrees(), axis=1)

        elif prog_args.data_mode == 'deg':
            for g in dataset:
                degs = list(g.in_degrees())
                degs_one_hot = one_hotify(degs, pad=dataset.max_degrees)
                g.ndata['feat'] = degs_one_hot
\ No newline at end of file
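
A quick usage sketch of one_hotify with toy labels (illustration only, not part of this commit):

labels = [0, 2, 1]
one_hotify(labels)          # 3 x 3 matrix, width inferred from the largest label
one_hotify(labels, pad=4)   # 3 x 5 matrix, padded to a fixed width of pad + 1
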
@@ -10,6 +10,7 @@ class Aggregator(nn.Module):
    This class is not supposed to be called
    """

    def __init__(self):
        super(Aggregator, self).__init__()
@@ -22,10 +23,12 @@ class Aggregator(nn.Module):
        # N x F
        raise NotImplementedError


class MeanAggregator(Aggregator):
    '''
    Mean Aggregator for graphsage
    '''

    def __init__(self):
        super(MeanAggregator, self).__init__()
@@ -33,15 +36,17 @@ class MeanAggregator(Aggregator):
        mean_neighbour = torch.mean(neighbour, dim=1)
        return mean_neighbour


class MaxPoolAggregator(Aggregator):
    '''
    Maxpooling aggregator for graphsage
    '''

    def __init__(self, in_feats, out_feats, activation, bias):
        super(MaxPoolAggregator, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats, bias=bias)
        self.activation = activation
        # Xavier initialization of weight
        nn.init.xavier_uniform_(self.linear.weight,
                                gain=nn.init.calculate_gain('relu'))
@@ -52,10 +57,12 @@ class MaxPoolAggregator(Aggregator):
        maxpool_neighbour = torch.max(neighbour, dim=1)[0]
        return maxpool_neighbour


class LSTMAggregator(Aggregator):
    '''
    LSTM aggregator for graphsage
    '''

    def __init__(self, in_feats, hidden_feats):
        super(LSTMAggregator, self).__init__()
        self.lstm = nn.LSTM(in_feats, hidden_feats, batch_first=True)
@@ -65,7 +72,6 @@ class LSTMAggregator(Aggregator):
        nn.init.xavier_uniform_(self.lstm.weight,
                                gain=nn.init.calculate_gain('relu'))

    def init_hidden(self):
        """
        Defaults to initializing all zeros
@@ -82,11 +88,12 @@ class LSTMAggregator(Aggregator):
        neighbours = neighbours[:, rand_order, :]
        (lstm_out, self.hidden) = self.lstm(neighbours.view(neighbours.size()[0],
                                                            neighbours.size()[1],
                                                            -1))
        return lstm_out[:, -1, :]

    def forward(self, node):
        neighbour = node.mailbox['m']
        c = self.aggre(neighbour)
        return {"c": c}
@@ -3,22 +3,22 @@ import torch.nn as nn
import torch.nn.functional as F


class Bundler(nn.Module):
    """
    Bundler, which will be the node_apply function in DGL paradigm
    """

    def __init__(self, in_feats, out_feats, activation, dropout, bias=True):
        super(Bundler, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(in_feats * 2, out_feats, bias)
        self.activation = activation

        nn.init.xavier_uniform_(self.linear.weight,
                                gain=nn.init.calculate_gain('relu'))

    def concat(self, h, aggre_result):
        bundle = torch.cat((h, aggre_result), 1)
        bundle = self.linear(bundle)
        return bundle
@@ -29,4 +29,4 @@ class Bundler(nn.Module):
        bundle = F.normalize(bundle, p=2, dim=1)
        if self.activation:
            bundle = self.activation(bundle)
        return {"h": bundle}
@@ -16,6 +16,7 @@ class GraphSageLayer(nn.Module):
    GraphSage layer in Inductive learning paper by hamilton
    Here, graphsage layer is a reduced function in DGL framework
    """

    def __init__(self, in_feats, out_feats, activation, dropout,
                 aggregator_type, bn=False, bias=True):
        super(GraphSageLayer, self).__init__()
@@ -50,19 +51,20 @@ class GraphSage(nn.Module):
    """
    Graphsage network that concatenates several graphsage layers
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation,
                 dropout, aggregator_type):
        super(GraphSage, self).__init__()
        self.layers = nn.ModuleList()

        # input layer
        self.layers.append(GraphSageLayer(in_feats, n_hidden, activation, dropout,
                                          aggregator_type))
        # hidden layers
        for _ in range(n_layers - 1):
            self.layers.append(GraphSageLayer(n_hidden, n_hidden, activation,
                                              dropout, aggregator_type))
        # output layer
        self.layers.append(GraphSageLayer(n_hidden, n_classes, None,
                                          dropout, aggregator_type))
@@ -72,50 +74,66 @@ class GraphSage(nn.Module):
            h = layer(g, h)
        return h


class DiffPoolBatchedGraphLayer(nn.Module):

    def __init__(self, input_dim, assign_dim, output_feat_dim,
                 activation, dropout, aggregator_type, link_pred):
        super(DiffPoolBatchedGraphLayer, self).__init__()
        self.embedding_dim = input_dim
        self.assign_dim = assign_dim
        self.hidden_dim = output_feat_dim
        self.link_pred = link_pred
        self.feat_gc = GraphSageLayer(input_dim, output_feat_dim, activation,
                                      dropout, aggregator_type)
        self.pool_gc = GraphSageLayer(input_dim, assign_dim, activation,
                                      dropout, aggregator_type)
        self.reg_loss = nn.ModuleList([])
        self.loss_log = {}
        self.reg_loss.append(EntropyLoss())

    def forward(self, g, h):
        feat = self.feat_gc(g, h)
        assign_tensor = self.pool_gc(g, h)
        device = feat.device
        assign_tensor_masks = []
        batch_size = len(g.batch_num_nodes)
        for g_n_nodes in g.batch_num_nodes:
            mask = torch.ones((g_n_nodes,
                               int(assign_tensor.size()[1] / batch_size)))
            assign_tensor_masks.append(mask)
        """
        The first pooling layer is computed on the batched graph.
        We first take the adjacency matrix of the batched graph, which is block-wise diagonal.
        We then compute the assignment matrix for the whole batched graph, which is also block diagonal.
        """
        mask = torch.FloatTensor(
            block_diag(*assign_tensor_masks)).to(device=device)
        assign_tensor = masked_softmax(assign_tensor, mask,
                                       memory_efficient=False)
        h = torch.matmul(torch.t(assign_tensor), feat)
        adj = g.adjacency_matrix(ctx=device)
        adj_new = torch.sparse.mm(adj, assign_tensor)
        adj_new = torch.mm(torch.t(assign_tensor), adj_new)

        if self.link_pred:
            current_lp_loss = torch.norm(adj.to_dense() -
                                         torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2)
            self.loss_log['LinkPredLoss'] = current_lp_loss

        for loss_layer in self.reg_loss:
            loss_name = str(type(loss_layer).__name__)
            self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)

        return adj_new, h
\ No newline at end of file
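
In matrix form, the forward pass above is the standard DiffPool coarsening, h' = S^T X and A' = S^T A S. A minimal dense sketch with toy shapes (the layer itself additionally applies the block-diagonal mask and uses the sparse batched adjacency, which this sketch omits):

import torch

N, K, D = 6, 2, 4                            # toy sizes: nodes, clusters, feature dim
A = torch.rand(N, N)                         # dense stand-in for the batched adjacency
X = torch.rand(N, D)                         # node features produced by feat_gc
S = torch.softmax(torch.rand(N, K), dim=1)   # soft assignment produced by pool_gc

h_pooled = S.t() @ X                         # K x D, matches h = S^T feat
adj_pooled = S.t() @ A @ S                   # K x K, matches adj_new = S^T A S
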
@@ -13,14 +13,16 @@ from .tensorized_layers import *
from .model_utils import batch2tensor
import time


class DiffPool(nn.Module):
    """
    DiffPool Fuse
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim,
                 label_dim, activation, n_layers, dropout,
                 n_pooling, linkpred, batch_size, aggregator_type,
                 assign_dim, pool_ratio, cat=False):
        super(DiffPool, self).__init__()
        self.link_pred = linkpred
        self.concat = cat
@@ -42,55 +44,93 @@ class DiffPool(nn.Module):
        # constructing layers
        # layers before diffpool
        assert n_layers >= 3, "n_layers too few"
        self.gc_before_pool.append(
            GraphSageLayer(input_dim, hidden_dim, activation, dropout,
                           aggregator_type, self.bn))
        for _ in range(n_layers - 2):
            self.gc_before_pool.append(
                GraphSageLayer(hidden_dim, hidden_dim, activation, dropout,
                               aggregator_type, self.bn))
        self.gc_before_pool.append(
            GraphSageLayer(hidden_dim, embedding_dim, None, dropout,
                           aggregator_type))

        assign_dims = []
        assign_dims.append(self.assign_dim)
        if self.concat:
            # the diffpool layer receives a pool_embedding_dim node feature tensor
            # and returns a pool_embedding_dim node embedding
            pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim
        else:
            pool_embedding_dim = embedding_dim

        self.first_diffpool_layer = DiffPoolBatchedGraphLayer(
            pool_embedding_dim,
            self.assign_dim,
            hidden_dim,
            activation,
            dropout,
            aggregator_type,
            self.link_pred)
        gc_after_per_pool = nn.ModuleList()
        for _ in range(n_layers - 1):
            gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))
        gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, embedding_dim))
        self.gc_after_pool.append(gc_after_per_pool)

        self.assign_dim = int(self.assign_dim * pool_ratio)

        # each pooling module
        for _ in range(n_pooling - 1):
            self.diffpool_layers.append(
                BatchedDiffPool(pool_embedding_dim, self.assign_dim,
                                hidden_dim, self.link_pred))
            gc_after_per_pool = nn.ModuleList()
            for _ in range(n_layers - 1):
                gc_after_per_pool.append(
                    BatchedGraphSAGE(hidden_dim, hidden_dim))
            gc_after_per_pool.append(
                BatchedGraphSAGE(hidden_dim, embedding_dim))
            self.gc_after_pool.append(gc_after_per_pool)
            assign_dims.append(self.assign_dim)
            self.assign_dim = int(self.assign_dim * pool_ratio)

        # predicting layer
        if self.concat:
            self.pred_input_dim = pool_embedding_dim * \
                self.num_aggs * (n_pooling + 1)
        else:
            self.pred_input_dim = embedding_dim * self.num_aggs
        self.pred_layer = nn.Linear(self.pred_input_dim, label_dim)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform_(m.weight.data,
                                                     gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    m.bias.data = init.constant_(m.bias.data, 0.0)

    def gcn_forward(self, g, h, gc_layers, cat=False):
        """
        Return gc_layer embedding cat.
@@ -102,22 +142,22 @@ class DiffPool(nn.Module):
        h = gc_layers[-1](g, h)
        block_readout.append(h)
        if cat:
            block = torch.cat(block_readout, dim=1)  # N x F, F = F1 + F2 + ...
        else:
            block = h
        return block

    def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False):
        block_readout = []
        for gc_layer in gc_layers:
            h = gc_layer(h, adj)
            block_readout.append(h)
        if cat:
            block = torch.cat(block_readout, dim=2)  # N x F, F = F1 + F2 + ...
        else:
            block = h
        return block

    def forward(self, g):
        self.link_pred_loss = []
        self.entropy_loss = []
@@ -138,22 +178,23 @@ class DiffPool(nn.Module):
        if self.num_aggs == 2:
            readout = dgl.max_nodes(g, 'h')
            out_all.append(readout)

        adj, h = self.first_diffpool_layer(g, g_embedding)
        node_per_pool_graph = int(adj.size()[0] / self.batch_size)
        h, adj = batch2tensor(adj, h, node_per_pool_graph)
        h = self.gcn_forward_tensorized(
            h, adj, self.gc_after_pool[0], self.concat)
        readout = torch.sum(h, dim=1)
        out_all.append(readout)
        if self.num_aggs == 2:
            readout, _ = torch.max(h, dim=1)
            out_all.append(readout)

        for i, diffpool_layer in enumerate(self.diffpool_layers):
            h, adj = diffpool_layer(h, adj)
            h = self.gcn_forward_tensorized(
                h, adj, self.gc_after_pool[i + 1], self.concat)
            readout = torch.sum(h, dim=1)
            out_all.append(readout)
            if self.num_aggs == 2:
@@ -165,7 +206,7 @@ class DiffPool(nn.Module):
            final_readout = readout
        ypred = self.pred_layer(final_readout)
        return ypred

    def loss(self, pred, label):
        '''
        loss function
@@ -175,5 +216,5 @@ class DiffPool(nn.Module):
        loss = criterion(pred, label)
        for diffpool_layer in self.diffpool_layers:
            for key, value in diffpool_layer.loss_log.items():
                loss += value
        return loss

import torch
import torch.nn as nn


class EntropyLoss(nn.Module):
    # Return Scalar
    def forward(self, adj, anext, s_l):
        entropy = (torch.distributions.Categorical(
            probs=s_l).entropy()).sum(-1).mean(-1)
        assert not torch.isnan(entropy)
        return entropy
@@ -12,7 +14,7 @@ class EntropyLoss(nn.Module):
class LinkPredLoss(nn.Module):

    def forward(self, adj, anext, s_l):
        link_pred_loss = (
            adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
        link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
        return link_pred_loss.mean()

import torch as th
from torch.autograd import Function


def batch2tensor(batch_adj, batch_feat, node_per_pool_graph):
    """
    transform a batched graph to batched adjacency tensor and node feature tensor
@@ -9,13 +10,13 @@ def batch2tensor(batch_adj, batch_feat, node_per_pool_graph):
    adj_list = []
    feat_list = []

    for i in range(batch_size):
        start = i * node_per_pool_graph
        end = (i + 1) * node_per_pool_graph
        adj_list.append(batch_adj[start:end, start:end])
        feat_list.append(batch_feat[start:end, :])
    adj_list = list(map(lambda x: th.unsqueeze(x, 0), adj_list))
    feat_list = list(map(lambda x: th.unsqueeze(x, 0), feat_list))
    adj = th.cat(adj_list, dim=0)
    feat = th.cat(feat_list, dim=0)

    return feat, adj
@@ -38,10 +39,7 @@ def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
        result = result * mask
        result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
    else:
        masked_matrix = matrix.masked_fill((1 - mask).byte(),
                                           mask_fill_value)
        result = th.nn.functional.softmax(masked_matrix, dim=dim)
    return result
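
masked_softmax is what restricts each graph's nodes to its own slice of the batched assignment columns. A small illustrative example with a hand-built toy mask (assumed sizes, not from the commit):

import torch as th

scores = th.rand(4, 6)    # 4 nodes from 2 graphs, 6 assignment columns in total
mask = th.zeros(4, 6)
mask[:2, :3] = 1          # nodes of graph 0 may only use the first 3 columns
mask[2:, 3:] = 1          # nodes of graph 1 may only use the last 3 columns

probs = masked_softmax(scores, mask, memory_efficient=False)
# each row sums to 1 and is zero outside its graph's block
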
@@ -5,6 +5,7 @@ from model.tensorized_layers.assignment import DiffPoolAssignment
from model.tensorized_layers.graphsage import BatchedGraphSAGE
from model.loss import EntropyLoss, LinkPredLoss


class BatchedDiffPool(nn.Module):
    def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True):
        super(BatchedDiffPool, self).__init__()
@@ -20,7 +21,6 @@ class BatchedDiffPool(nn.Module):
        if entropy:
            self.reg_loss.append(EntropyLoss())

    def forward(self, x, adj, log=False):
        z_l = self.embed(x, adj)
        s_l = self.assign(x, adj)
@@ -35,5 +35,3 @@ class BatchedDiffPool(nn.Module):
        if log:
            self.log['a'] = anext.cpu().numpy()
        return xnext, anext
@@ -4,14 +4,17 @@ from torch.nn import functional as F


class BatchedGraphSAGE(nn.Module):
    def __init__(self, infeat, outfeat, use_bn=True,
                 mean=False, add_self=False):
        super().__init__()
        self.add_self = add_self
        self.use_bn = use_bn
        self.mean = mean
        self.W = nn.Linear(infeat, outfeat, bias=True)

        nn.init.xavier_uniform_(
            self.W.weight, gain=nn.init.calculate_gain('relu'))

    def forward(self, x, adj):
        if self.use_bn and not hasattr(self, 'bn'):
...
@@ -17,20 +17,37 @@ from dgl.data import tu
from model.encoder import DiffPool
from data_utils import pre_process


def arg_parse():
    '''
    argument parser
    '''
    parser = argparse.ArgumentParser(description='DiffPool arguments')
    parser.add_argument('--dataset', dest='dataset', help='Input Dataset')
    parser.add_argument('--pool_ratio', dest='pool_ratio', type=float,
                        help='pooling ratio')
    parser.add_argument('--num_pool', dest='num_pool', type=int,
                        help='num_pooling layer')
    parser.add_argument('--no_link_pred', dest='linkpred', action='store_false',
                        help='switch of link prediction object')
    parser.add_argument('--cuda', dest='cuda', type=int, help='switch cuda')
    parser.add_argument('--lr', dest='lr', type=float, help='learning rate')
    parser.add_argument('--clip', dest='clip', type=float,
                        help='gradient clipping')
    parser.add_argument('--batch-size', dest='batch_size', type=int,
                        help='batch size')
    parser.add_argument('--epochs', dest='epoch', type=int,
                        help='num-of-epoch')
    parser.add_argument('--train-ratio', dest='train_ratio', type=float,
@@ -47,13 +64,16 @@ def arg_parse():
                        help='dropout rate')
    parser.add_argument('--bias', dest='bias', action='store_const',
                        const=True, default=True, help='switch for bias')
    parser.add_argument('--save_dir', dest='save_dir',
                        help='model saving directory: SAVE_DICT/DATASET')
    parser.add_argument('--load_epoch', dest='load_epoch', help='load trained model params from\
                        SAVE_DICT/DATASET/model-LOAD_EPOCH')
    parser.add_argument('--data_mode', dest='data_mode', help='data\
                        preprocessing mode: default, id, degree, or one-hot\
                        vector of degree number', choices=['default', 'id', 'deg',
                                                           'deg_num'])
    parser.set_defaults(dataset='ENZYMES',
                        pool_ratio=0.15,
@@ -73,9 +93,10 @@ def arg_parse():
                        bias=True,
                        save_dir="./model_param",
                        load_epoch=-1,
                        data_mode='default')
    return parser.parse_args()


def prepare_data(dataset, prog_args, train=False, pre_process=None):
    '''
    preprocess TU dataset according to DiffPool's paper setting and load dataset into dataloader
@@ -84,11 +105,11 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None):
        shuffle = True
    else:
        shuffle = False

    if pre_process:
        pre_process(dataset, prog_args)

    # dataset.set_fold(fold)
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=prog_args.batch_size,
                                       shuffle=shuffle,
@@ -101,13 +122,14 @@ def graph_classify_task(prog_args):
    '''
    perform graph classification task
    '''
    dataset = tu.TUDataset(name=prog_args.dataset)
    train_size = int(prog_args.train_ratio * len(dataset))
    test_size = int(prog_args.test_ratio * len(dataset))
    val_size = int(len(dataset) - train_size - test_size)

    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size))
    train_dataloader = prepare_data(dataset_train, prog_args, train=True,
                                    pre_process=pre_process)
    val_dataloader = prepare_data(dataset_val, prog_args, train=False,
@@ -122,47 +144,52 @@ def graph_classify_task(prog_args):
    print("number of graphs is", len(dataset))
    # assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size"

    hidden_dim = 64  # used to be 64
    embedding_dim = 64

    # calculate assignment dimension: pool_ratio * largest graph's maximum
    # number of nodes in the dataset
    assign_dim = int(max_num_node * prog_args.pool_ratio) * \
        prog_args.batch_size
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)
    activation = F.relu

    # initialize model
    # 'diffpool' : diffpool
    model = DiffPool(input_dim,
                     hidden_dim,
                     embedding_dim,
                     label_dim,
                     activation,
                     prog_args.gc_per_block,
                     prog_args.dropout,
                     prog_args.num_pool,
                     prog_args.linkpred,
                     prog_args.batch_size,
                     'meanpool',
                     assign_dim,
                     prog_args.pool_ratio)

    if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
        model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                         + "/model.iter-" + str(prog_args.load_epoch)))

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    if prog_args.cuda:
        model = model.cuda()

    logger = train(train_dataloader, model, prog_args,
                   val_dataset=val_dataloader)
    result = evaluate(test_dataloader, model, prog_args, logger)
    print("test accuracy {}%".format(result * 100))


def collate_fn(batch):
    '''
@@ -183,6 +210,7 @@ def collate_fn(batch):
    return batched_graphs, batched_labels


def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    '''
    training function
@@ -193,7 +221,7 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    dataloader = dataset
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()), lr=0.001)
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0:
        torch.cuda.set_device(0)
@@ -210,7 +238,6 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
                    batch_graph.ndata[key] = value.cuda()
                graph_labels = graph_labels.cuda()

            model.zero_grad()
            compute_start = time.time()
            ypred = model(batch_graph)
@@ -225,32 +252,33 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        train_accu = accum_correct / total
        print("train accuracy for this epoch {} is {}%".format(epoch,
                                                               train_accu * 100))
        elapsed_time = time.time() - begin_time
        print("loss {} with epoch time {} s & computation time {} s ".format(
            loss.item(), elapsed_time, computation_time))
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation accuracy {}%".format(result * 100))
            if result >= early_stopping_logger['val_acc'] and result <=\
                    train_accu:
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if prog_args.save_dir is not None:
                    torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
                               + "/model.iter-" + str(early_stopping_logger['best_epoch']))
            print("best epoch is EPOCH {}, val_acc is {}%".format(early_stopping_logger['best_epoch'],
                                                                  early_stopping_logger['val_acc'] * 100))
        torch.cuda.empty_cache()
    return early_stopping_logger


def evaluate(dataloader, model, prog_args, logger=None):
    '''
    evaluate function
    '''
    if logger is not None and prog_args.save_dir is not None:
        model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                         + "/model.iter-" + str(logger['best_epoch'])))
    model.eval()
    correct_label = 0
@@ -262,11 +290,12 @@ def evaluate(dataloader, model, prog_args, logger=None):
            graph_labels = graph_labels.cuda()
        ypred = model(batch_graph)
        indi = torch.argmax(ypred, dim=1)
        correct = torch.sum(indi == graph_labels)
        correct_label += correct.item()
    result = correct_label / (len(dataloader) * prog_args.batch_size)
    return result


def main():
    '''
    main
...
@@ -24,7 +24,8 @@ class TUDataset(object):
    _url = r"https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/{}.zip"

    def __init__(self, name, use_pandas=False,
                 hidden_size=10, max_allow_node=None):

        self.name = name
        self.hidden_size = hidden_size
@@ -73,7 +74,8 @@ class TUDataset(object):
            print("No Node Label Data")

        try:
            DS_node_attr = np.loadtxt(
                self._file_path("node_attributes"), delimiter=",")
            for idxs, g in zip(node_idx_list, self.graph_lists):
                g.ndata['feat'] = DS_node_attr[idxs, :]
            self.data_mode = "node_attr"
@@ -84,8 +86,9 @@ class TUDataset(object):
            for idxs, g in zip(node_idx_list, self.graph_lists):
                g.ndata['feat'] = np.ones((g.number_of_nodes(), hidden_size))
            self.data_mode = "constant"
            print(
                "Use Constant one as Feature with hidden size {}".format(hidden_size))

        # remove graphs that are too large by user given standard
        # optional pre-processing step in conformity with Rex Ying's original
        # DiffPool implementation
@@ -96,11 +99,12 @@ class TUDataset(object):
                if g.number_of_nodes() <= self.max_allow_node:
                    preserve_idx.append(i)
            self.graph_lists = [self.graph_lists[i] for i in preserve_idx]
            print("after pruning graphs that are too big : ",
                  len(self.graph_lists))
            self.graph_labels = [self.graph_labels[i] for i in preserve_idx]
            self.max_num_node = self.max_allow_node

    def __getitem__(self, idx):
        """Get the i^th sample.

        Parameters
@@ -121,14 +125,18 @@ class TUDataset(object):
    def _download(self):
        download_dir = get_download_dir()
        zip_file_path = os.path.join(
            download_dir, "tu_{}.zip".format(self.name))
        download(self._url.format(self.name), path=zip_file_path)
        extract_dir = os.path.join(download_dir, "tu_{}".format(self.name))
        extract_archive(zip_file_path, extract_dir)
        return extract_dir

    def _file_path(self, category):
        return os.path.join(self.extract_dir, self.name,
                            "{}_{}.txt".format(self.name, category))

    @staticmethod
    def _idx_from_zero(idx_tensor):
@@ -144,6 +152,5 @@ class TUDataset(object):
    def statistics(self):
        return self.graph_lists[0].ndata['feat'].shape[1],\
            self.num_labels,\
            self.max_num_node