Unverified Commit 421b05e7 authored by Mufei Li, committed by GitHub

[Model] Clean up for Cluster GCN (#940)

* Update

* Update

* Update

* Update

* Update

* Update

* Fix

* Fix
parent 52c7ef49
cluster_gcn.py
@@ -12,7 +12,7 @@ from dgl import DGLGraph
 from dgl.data import register_data_args
 from torch.utils.tensorboard import SummaryWriter
-from modules import GCNCluster, GraphSAGE
+from modules import GraphSAGE
 from sampler import ClusterIter
 from utils import Logger, evaluate, save_log_dir, load_data
@@ -24,16 +24,13 @@ def main(args):
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
-    multitask_data = set(['ppi', 'amazon', 'amazon-0.1',
-                          'amazon-0.3', 'amazon2M', 'amazon2M-47'])
+    multitask_data = set(['ppi'])
     multitask = args.dataset in multitask_data
     # load and preprocess dataset
     data = load_data(args)
     train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)
-    test_nid = np.nonzero(data.test_mask)[0].astype(np.int64)
     # Normalize features
     if args.normalize:
@@ -102,16 +99,13 @@ def main(args):
     print("features shape, ", features.shape)
-    model_sel = {'GCN': GCNCluster, 'graphsage': GraphSAGE}
-    model_class = model_sel[args.model_type]
-    print('using model:', model_class)
-    model = model_class(in_feats,
-                        args.n_hidden,
-                        n_classes,
-                        args.n_layers,
-                        F.relu,
-                        args.dropout, args.use_pp)
+    model = GraphSAGE(in_feats,
+                      args.n_hidden,
+                      n_classes,
+                      args.n_layers,
+                      F.relu,
+                      args.dropout,
+                      args.use_pp)
     if cuda:
         model.cuda()
@@ -135,9 +129,6 @@ def main(args):
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
-    # initialize graph
-    dur = []
     # set train_nids to cuda tensor
     if cuda:
         train_nid = torch.from_numpy(train_nid).cuda()
@@ -164,7 +155,8 @@ def main(args):
             # in PPI case, `log_every` is chosen to log one time per epoch.
             # Choose your log freq dynamically when you want more info within one epoch
             if j % args.log_every == 0:
-                print(f"epoch:{epoch}/{args.n_epochs}, Iteration {j}/{len(cluster_iterator)}:training loss", loss.item())
+                print(f"epoch:{epoch}/{args.n_epochs}, Iteration {j}/"
+                      f"{len(cluster_iterator)}:training loss", loss.item())
                 writer.add_scalar('train/loss', loss.item(),
                                   global_step=j + epoch * len(cluster_iterator))
                 print("current memory:",
@@ -193,12 +185,10 @@ def main(args):
         log_dir, 'best_model.pkl')))
     test_f1_mic, test_f1_mac = evaluate(
         model, g, labels, test_mask, multitask)
-    print(
-        "Test F1-mic{:.4f}, Test F1-mac{:.4f}". format(test_f1_mic, test_f1_mac))
+    print("Test F1-mic{:.4f}, Test F1-mac{:.4f}". format(test_f1_mic, test_f1_mac))
     writer.add_scalar('test/f1-mic', test_f1_mic)
     writer.add_scalar('test/f1-mac', test_f1_mac)

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GCN')
     register_data_args(parser)
@@ -236,8 +226,6 @@ if __name__ == '__main__':
                         help="whether to use validated best model to test")
     parser.add_argument("--weight-decay", type=float, default=5e-4,
                         help="Weight for L2 loss")
-    parser.add_argument("--model-type", type=str, default='GCN',
-                        help="model to be used")
     parser.add_argument("--note", type=str, default='none',
                         help="note for log dir")
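Note (not part of the diff): after this cleanup the script always builds GraphSAGE directly and takes one optimizer step per cluster sampled by ClusterIter. Below is a minimal, self-contained sketch of that training pattern; the linear model and the random (features, labels) batches are stand-ins so it runs without DGL or the datasets.

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(16, 4)                  # stand-in for GraphSAGE(...)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)

# Pretend each "cluster" is a (features, labels) batch from ClusterIter.
cluster_iterator = [(torch.randn(8, 16), torch.randint(0, 4, (8,)))
                    for _ in range(5)]

for epoch in range(3):
    for j, (feats, labels) in enumerate(cluster_iterator):
        logits = model(feats)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if j % 2 == 0:                    # cf. the --log-every logging above
            print(f"epoch:{epoch}, iteration {j}: training loss", loss.item())

The real script differs only in the model (GraphSAGE run on a cluster subgraph) and in the extra TensorBoard logging.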
modules.py
@@ -4,9 +4,7 @@ import dgl.function as fn
 import torch
 import torch.nn as nn

-
-
-class GCNLayer(nn.Module):
+class GraphSAGELayer(nn.Module):
     def __init__(self,
                  in_feats,
                  out_feats,
@@ -15,12 +13,10 @@ class GCNLayer(nn.Module):
                  bias=True,
                  use_pp=False,
                  use_lynorm=True):
-        super(GCNLayer, self).__init__()
-        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
-        if bias:
-            self.bias = nn.Parameter(torch.Tensor(out_feats))
-        else:
-            self.bias = None
+        super(GraphSAGELayer, self).__init__()
+        # The input feature size gets doubled as we concatenated the original
+        # features with the new features.
+        self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias)
         self.activation = activation
         self.use_pp = use_pp
         if dropout:
@@ -34,16 +30,15 @@ class GCNLayer(nn.Module):
         self.reset_parameters()

     def reset_parameters(self):
-        stdv = 1. / math.sqrt(self.weight.size(1))
-        self.weight.data.uniform_(-stdv, stdv)
-        if self.bias is not None:
-            self.bias.data.uniform_(-stdv, stdv)
+        stdv = 1. / math.sqrt(self.linear.weight.size(1))
+        self.linear.weight.data.uniform_(-stdv, stdv)
+        if self.linear.bias is not None:
+            self.linear.bias.data.uniform_(-stdv, stdv)

-    def forward(self, g):
-        h = g.ndata['h']
-        norm = self.get_norm(g)
+    def forward(self, g, h):
+        g = g.local_var()
         if not self.use_pp or not self.training:
+            norm = self.get_norm(g)
             g.ndata['h'] = h
             g.update_all(fn.copy_src(src='h', out='m'),
                          fn.sum(msg='m', out='h'))
@@ -52,75 +47,25 @@ class GCNLayer(nn.Module):
         if self.dropout:
             h = self.dropout(h)
-        h = torch.mm(h, self.weight)
-        if self.bias is not None:
-            h = h + self.bias
+        h = self.linear(h)
         h = self.lynorm(h)
         if self.activation:
             h = self.activation(h)
         return h

     def concat(self, h, ah, norm):
-        # normalization by square root of dst degree
-        return ah * norm
+        ah = ah * norm
+        h = torch.cat((h, ah), dim=1)
+        return h

     def get_norm(self, g):
         norm = 1. / g.in_degrees().float().unsqueeze(1)
-        # .sqrt()
         norm[torch.isinf(norm)] = 0
-        norm = norm.to(self.weight.device)
+        norm = norm.to(self.linear.weight.device)
         return norm

-
-class GCNCluster(nn.Module):
-    def __init__(self,
-                 in_feats,
-                 n_hidden,
-                 n_classes,
-                 n_layers,
-                 activation,
-                 dropout,
-                 use_pp):
-        super(GCNCluster, self).__init__()
-        self.layers = nn.ModuleList()
-        # input layer
-        self.layers.append(
-            GCNLayer(in_feats, n_hidden, activation=activation, dropout=dropout, use_pp=use_pp, use_lynorm=True))
-        # hidden layers
-        for i in range(n_layers):
-            self.layers.append(
-                GCNLayer(n_hidden, n_hidden, activation=activation, dropout=dropout, use_lynorm=True))
-        # output layer
-        self.layers.append(GCNLayer(n_hidden, n_classes,
-                                    activation=None, dropout=dropout, use_lynorm=False))
-
-    def forward(self, g):
-        g.ndata['h'] = g.ndata['features']
-        for i, layer in enumerate(self.layers):
-            g.ndata['h'] = layer(g)
-        h = g.ndata.pop('h')
-        return h
-
-
-class GCNLayerSAGE(GCNLayer):
-    def __init__(self, *args, **xargs):
-        super(GCNLayerSAGE, self).__init__(*args, **xargs)
-        in_feats, out_feats = self.weight.shape
-        self.weight = nn.Parameter(torch.Tensor(2 * in_feats, out_feats))
-        self.reset_parameters()
-
-    def concat(self, h, ah, norm):
-        ah = ah * norm
-        h = torch.cat((h, ah), dim=1)
-        return h
 class GraphSAGE(nn.Module):
     def __init__(self,
                  in_feats,
                  n_hidden,
@@ -138,15 +83,14 @@ class GraphSAGE(nn.Module):
         # hidden layers
         for i in range(n_layers - 1):
             self.layers.append(
-                GCNLayerSAGE(n_hidden, n_hidden, activation=activation, dropout=dropout, use_pp=False, use_lynorm=True))
+                GraphSAGELayer(n_hidden, n_hidden, activation=activation, dropout=dropout,
+                               use_pp=False, use_lynorm=True))
         # output layer
-        self.layers.append(GCNLayerSAGE(n_hidden, n_classes, activation=None,
-                                        dropout=dropout, use_pp=False, use_lynorm=False))
+        self.layers.append(GraphSAGELayer(n_hidden, n_classes, activation=None,
+                                          dropout=dropout, use_pp=False, use_lynorm=False))

     def forward(self, g):
         h = g.ndata['features']
-        g.ndata['h'] = h
         for layer in self.layers:
-            g.ndata['h'] = layer(g)
-        h = g.ndata.pop('h')
+            h = layer(g, h)
         return h
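Note (not part of the diff): the renamed GraphSAGELayer concatenates each node's own features with its degree-normalized aggregated neighbour features before the linear projection, which is why self.linear takes 2 * in_feats inputs. A dense-matrix sketch of the same computation, with toy adjacency and sizes made up for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

n, in_feats, out_feats = 5, 8, 3
A = (torch.rand(n, n) < 0.4).float()      # toy adjacency matrix
h = torch.randn(n, in_feats)

# get_norm: 1 / in-degree (the layer zeroes infinities; clamped here instead)
norm = 1. / A.sum(dim=1, keepdim=True).clamp(min=1)
ah = (A @ h) * norm                       # sum over neighbours, then normalize

# concat(): own features next to aggregated ones -> 2 * in_feats columns
linear = nn.Linear(2 * in_feats, out_feats)
h_out = F.relu(linear(torch.cat((h, ah), dim=1)))
print(h_out.shape)                        # torch.Size([5, 3])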
partition_utils.py
@@ -5,7 +5,6 @@ import numpy as np
 from utils import arg_list

-
 def get_partition_list(g, psize):
     tmp_time = time()
     ng = g.to_networkx()
@@ -17,7 +16,6 @@ def get_partition_list(g, psize):
     al = arg_list(nd_group)
     return al

-
 def get_subgraph(g, par_arr, i, psize, batch_size):
     par_batch_ind_arr = [par_arr[s] for s in range(
         i * batch_size, (i + 1) * batch_size) if s < psize]
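Note (not part of the diff): get_partition_list has METIS assign each node a partition id and then groups node ids by partition via utils.arg_list; get_subgraph later merges batch_size of those partitions into one training subgraph. A small sketch of the grouping step, with random ids standing in for the METIS output:

import numpy as np

psize = 4
nd_group = np.random.randint(0, psize, size=20)   # fake METIS assignment

# Same idea as utils.arg_list: one array of node ids per partition.
partitions = [np.flatnonzero(nd_group == p) for p in range(psize)]
for p, nodes in enumerate(partitions):
    print(f"partition {p}: {len(nodes)} nodes")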
(shell script: ppi run)
 #!/bin/bash
-python cluster_gcn.py --gpu 0 --dataset ppi --lr 1e-2 --weight-decay 0.0 --psize 50 --batch-size 1 --n-epochs 300 --n-hidden 2048 --n-layers 3 --log-every 100 --use-pp --self-loop --note self-loop-ppi-non-sym-ly3-pp-cluster-2-2-wd-0 --dropout 0.2 --model-type graphsage --use-val --normalize
\ No newline at end of file
+python cluster_gcn.py --gpu 0 --dataset ppi --lr 1e-2 --weight-decay 0.0 --psize 50 --batch-size 1 --n-epochs 300 \
+--n-hidden 2048 --n-layers 3 --log-every 100 --use-pp --self-loop \
+--note self-loop-ppi-non-sym-ly3-pp-cluster-2-2-wd-0 --dropout 0.2 --use-val --normalize

(shell script: reddit run)
 #!/bin/bash
-python cluster_gcn.py --gpu 0 --dataset reddit-self-loop --lr 1e-2 --weight-decay 0.0 --psize 1500 --batch-size 20 --n-epochs 30 --n-hidden 128 --n-layers 0 --log-every 100 --use-pp --self-loop --note self-loop-reddit-non-sym-ly3-pp-cluster-2-2-wd-5e-4 --dropout 0.2 --model-type graphsage --use-val --normalize
\ No newline at end of file
+python cluster_gcn.py --gpu 0 --dataset reddit-self-loop --lr 1e-2 --weight-decay 0.0 --psize 1500 --batch-size 20 \
+--n-epochs 30 --n-hidden 128 --n-layers 0 --log-every 100 --use-pp --self-loop \
+--note self-loop-reddit-non-sym-ly3-pp-cluster-2-2-wd-5e-4 --dropout 0.2 --use-val --normalize
sampler.py
@@ -2,17 +2,15 @@ import os
 import random
 import dgl.function as fn
-import numpy as np
 import torch
 from partition_utils import *

-
 class ClusterIter(object):
-    '''The partition sampler given a DGLGraph and partition number.
-    The metis is used as the graph partition backend.
-    '''
+    '''
+    The partition sampler given a DGLGraph and partition number. The metis is used as the graph partition backend.
+    '''

     def __init__(self, dn, g, psize, batch_size, seed_nid, use_pp=True):
         """Initialize the sampler.
@@ -65,8 +63,7 @@ class ClusterIter(object):
         with torch.no_grad():
             g.update_all(fn.copy_src(src='features', out='m'),
                          fn.sum(msg='m', out='features'),
-                         None
-                         )
+                         None)
             pre_feats = g.ndata['features'] * norm
             # use graphsage embedding aggregation style
             g.ndata['features'] = torch.cat([features, pre_feats], dim=1)
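Note (not part of the diff): when use_pp is set, the block above precomputes the first layer's neighbour aggregation once over the full graph and concatenates it into g.ndata['features'], so the first GraphSAGELayer can skip message passing inside each cluster during training. A dense-tensor sketch of that precomputation, with a toy adjacency standing in for the DGL graph:

import torch

n, in_feats = 6, 4
A = (torch.rand(n, n) < 0.5).float()      # toy full-graph adjacency
features = torch.randn(n, in_feats)

# 1 / in-degree, as in the sampler's `norm` (infinities zeroed there)
norm = 1. / A.sum(dim=1, keepdim=True).clamp(min=1)
with torch.no_grad():                     # mirrors the torch.no_grad() above
    pre_feats = (A @ features) * norm
    # graphsage style: original features next to the aggregated ones
    features = torch.cat([features, pre_feats], dim=1)
print(features.shape)                     # torch.Size([6, 8])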
utils.py
@@ -8,12 +8,8 @@ from dgl.data import PPIDataset
 from dgl.data import load_data as _load_data
 from sklearn.metrics import f1_score

-
 class Logger(object):
-    '''
-    A custom logger to log stdout to a logging file.
-    '''
+    '''A custom logger to log stdout to a logging file.'''

     def __init__(self, path):
         """Initialize the logger.
@@ -30,7 +26,6 @@ class Logger(object):
         print(s)
         return

-
 def arg_list(labels):
     hist, indexes, inverse, counts = np.unique(
         labels, return_index=True, return_counts=True, return_inverse=True)
@@ -39,7 +34,6 @@ def arg_list(labels):
         li.append(np.argwhere(inverse == h))
     return li

-
 def save_log_dir(args):
     log_dir = './log/{}/{}'.format(args.dataset, args.note)
     os.makedirs(log_dir, exist_ok=True)
@@ -54,7 +48,6 @@ def calc_f1(y_true, y_pred, multitask):
     return f1_score(y_true, y_pred, average="micro"), \
         f1_score(y_true, y_pred, average="macro")

-
 def evaluate(model, g, labels, mask, multitask=False):
     model.eval()
     with torch.no_grad():
@@ -65,11 +58,8 @@ def evaluate(model, g, labels, mask, multitask=False):
                     logits.cpu().numpy(), multitask)
     return f1_mic, f1_mac

-
 def load_data(args):
-    '''
-    wraps the dgl's load_data utility to handle ppi special case
-    '''
+    '''Wraps the dgl's load_data utility to handle ppi special case'''
     if args.dataset != 'ppi':
         return _load_data(args)
     train_dataset = PPIDataset('train')
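Note (not part of the diff): calc_f1 returns both averaging modes of sklearn's f1_score. Micro-F1 pools all predictions before computing precision and recall, while macro-F1 computes F1 per class and then averages, so rare classes weigh more under macro. A tiny example with made-up labels:

from sklearn.metrics import f1_score

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 0, 2, 1]
print("micro:", f1_score(y_true, y_pred, average="micro"))
print("macro:", f1_score(y_true, y_pred, average="macro"))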