Unverified commit 704bcaf6, authored by Hongzhi (Steve), Chen and committed by GitHub
parent 6bc82161
@@ -280,7 +280,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
    all_but_fc_params = list(
        filter(lambda x: "last_linear" not in x[0], model.named_parameters())
    )
@@ -376,6 +375,8 @@ def same_model(model1, model2):
"""============================================================================"""


#################### TRAINER FUNCTION ############################
def train_one_epoch_finetune(
    train_dataloader, model, optimizer, criterion, opt, epoch
@@ -403,7 +404,6 @@ def train_one_epoch_finetune(
        train_dataloader, desc="Epoch {} Training gt labels...".format(epoch)
    )
    for i, (class_labels, input) in enumerate(data_iterator):
        # Compute embeddings for input batch
        features = model(input.to(opt.device))
...
@@ -263,7 +263,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
    all_but_fc_params = list(
        filter(lambda x: "last_linear" not in x[0], model.named_parameters())
    )
...
@@ -11,6 +11,8 @@ import torch
from scipy import sparse

"""================================================================================================="""


############ LOSS SELECTION FUNCTION #####################
def loss_select(loss, opt, to_optim):
    """
...
@@ -281,7 +281,6 @@ _ = model.to(opt.device)
# Place trainable parameter in list of parameters to train:
if "fc_lr_mul" in vars(opt).keys() and opt.fc_lr_mul != 0:
    all_but_fc_params = list(
        filter(lambda x: "last_linear" not in x[0], model.named_parameters())
    )
...
import argparse, os, pickle, time
import random
import sys

sys.path.append("..")
import shutil

import dgl
import numpy as np
import seaborn
import torch
import torch.optim as optim
from dataset import LanderDataset
from matplotlib import pyplot as plt
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating
from utils.deduce import get_edge_dist

STATISTIC = False
@@ -25,43 +26,47 @@ STATISTIC = False
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--model_filename", type=str, default="lander.pth")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--output_filename", type=str, default="data/features.pkl")
# HyperParam
parser.add_argument("--knn_k", type=int, default=10)
parser.add_argument("--levels", type=int, default=1)
parser.add_argument("--tau", type=float, default=0.5)
parser.add_argument("--threshold", type=str, default="prob")
parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")
parser.add_argument("--early_stop", action="store_true")
# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
parser.add_argument("--use_gt", action="store_true")
# Subgraph
parser.add_argument("--batch_size", type=int, default=4096)
parser.add_argument("--mode", type=str, default="1head")
parser.add_argument("--midpoint", type=str, default="false")
parser.add_argument("--linsize", type=int, default=29011)
parser.add_argument("--uinsize", type=int, default=18403)
parser.add_argument("--inclasses", type=int, default=948)
parser.add_argument("--thresh", type=float, default=1.0)
parser.add_argument("--draw", type=str, default="false")
parser.add_argument(
    "--density_distance_pkl", type=str, default="density_distance.pkl"
)
parser.add_argument(
    "--density_lindistance_jpg", type=str, default="density_lindistance.jpg"
)

args = parser.parse_args()
print(args)
@@ -70,21 +75,21 @@ linsize = args.linsize
uinsize = args.uinsize
inclasses = args.inclasses

if args.draw == "false":
    args.draw = False
elif args.draw == "true":
    args.draw = True

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
with open(args.data_path, "rb") as f:
    loaded_data = pickle.load(f)
    path2idx, features, pred_labels, labels, masks = loaded_data
@@ -123,11 +128,12 @@ else:
print("filtered features:", len(features)) print("filtered features:", len(features))
global_features = features.copy() # global features global_features = features.copy() # global features
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k, dataset = LanderDataset(
levels=1, faiss_gpu=False) features=features, labels=labels, k=args.knn_k, levels=1, faiss_gpu=False
)
g = dataset.gs[0] g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes())) g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2)) g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
global_labels = labels.copy() global_labels = labels.copy()
ids = np.arange(g.number_of_nodes()) ids = np.arange(g.number_of_nodes())
global_edges = ([], []) global_edges = ([], [])
...@@ -135,7 +141,7 @@ global_peaks = np.array([], dtype=np.long) ...@@ -135,7 +141,7 @@ global_peaks = np.array([], dtype=np.long)
global_edges_len = len(global_edges[0]) global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes() global_num_nodes = g.number_of_nodes()
global_densities = g.ndata['density'][:linsize] global_densities = g.ndata["density"][:linsize]
global_densities = np.sort(global_densities) global_densities = np.sort(global_densities)
xs = np.arange(len(global_densities)) xs = np.arange(len(global_densities))
...@@ -143,23 +149,30 @@ fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)] ...@@ -143,23 +149,30 @@ fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges
test_loader = dgl.dataloading.DataLoader(
    g,
    torch.arange(g.number_of_nodes()),
    sampler,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)
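# --- Illustrative note, not part of this commit ---
# The hunk context above sets fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]:
# every message-passing hop samples at most k - 1 neighbours (the full kNN
# neighbourhood minus the self edge), which keeps the per-batch subgraphs
# bounded. With the defaults knn_k=10 and num_conv=4 this evaluates to
#     [args.knn_k - 1 for i in range(args.num_conv + 1)]  ->  [9, 9, 9, 9, 9]
# ---------------------------------------------------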
##################
# Model Definition
if not args.use_gt:
    feature_dim = g.ndata["features"].shape[1]
    model = LANDER(
        feature_dim=feature_dim,
        nhid=args.hidden,
        num_conv=args.num_conv,
        dropout=args.dropout,
        use_GAT=args.gat,
        K=args.gat_k,
        balance=args.balance,
        use_cluster_feat=args.use_cluster_feat,
        use_focal_loss=args.use_focal_loss,
    )
    model.load_state_dict(torch.load(args.model_filename))
    model = model.to(device)
    model.eval()
@@ -179,46 +192,82 @@ for level in range(args.levels):
        with torch.no_grad():
            output_bipartite = model(bipartites)
        global_nid = output_bipartite.dstdata[dgl.NID]
        global_eid = output_bipartite.edata["global_eid"]
        g.ndata["pred_den"][global_nid] = output_bipartite.dstdata[
            "pred_den"
        ].to("cpu")
        g.edata["prob_conn"][global_eid] = output_bipartite.edata[
            "prob_conn"
        ].to("cpu")
        torch.cuda.empty_cache()
        if (batch + 1) % 10 == 0:
            print("Batch %d / %d for inference" % (batch, total_batches))
    (
        new_pred_labels,
        peaks,
        global_edges,
        global_pred_labels,
        global_peaks,
    ) = decode(
        g,
        args.tau,
        args.threshold,
        args.use_gt,
        ids,
        global_edges,
        global_num_nodes,
        global_peaks,
    )
    if level == 0:
        global_pred_densities = g.ndata["pred_den"]
        global_densities = g.ndata["density"]

    g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
    ids = ids[peaks]
    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
    if stop_iterating(
        level,
        args.levels,
        args.early_stop,
        num_edges_add_this_level,
        num_edges_add_last_level,
        args.knn_k,
    ):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level

    # build new dataset
    features, labels, cluster_features = build_next_level(
        features,
        labels,
        peaks,
        global_features,
        global_pred_labels,
        global_peaks,
    )
    # After the first level, the number of nodes drops a lot, so CPU faiss is faster.
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=args.knn_k,
        levels=1,
        faiss_gpu=False,
        cluster_features=cluster_features,
    )
    g = dataset.gs[0]
    g.ndata["pred_den"] = torch.zeros((g.number_of_nodes()))
    g.edata["prob_conn"] = torch.zeros((g.number_of_edges(), 2))
    test_loader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )
if MODE == "selectbydensity": if MODE == "selectbydensity":
...@@ -261,8 +310,10 @@ if MODE == "selectbydensity": ...@@ -261,8 +310,10 @@ if MODE == "selectbydensity":
idx = np.where(l_in_gt_new == i) idx = np.where(l_in_gt_new == i)
prototypes[i] = np.mean(l_in_features[idx], axis=0) prototypes[i] = np.mean(l_in_features[idx], axis=0)
similarity_matrix = torch.mm(torch.from_numpy(global_features.astype(np.float32)), similarity_matrix = torch.mm(
torch.from_numpy(prototypes.astype(np.float32)).t()) torch.from_numpy(global_features.astype(np.float32)),
torch.from_numpy(prototypes.astype(np.float32)).t(),
)
similarity_matrix = (1 - similarity_matrix) / 2 similarity_matrix = (1 - similarity_matrix) / 2
minvalues, selected_pred_labels = torch.min(similarity_matrix, 1) minvalues, selected_pred_labels = torch.min(similarity_matrix, 1)
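    # --- Illustrative note, not part of this commit ---
    # Assuming the features and prototypes are L2-normalised, the matrix product
    # above holds cosine similarities in [-1, 1]; (1 - sim) / 2 rescales them to
    # a distance in [0, 1] (sim = 1 -> 0.0, sim = -1 -> 1.0), and torch.min over
    # dim 1 returns, per sample, the distance to and index of the closest class
    # prototype.
    # ---------------------------------------------------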
    # far-close ratio
@@ -274,7 +325,7 @@ if MODE == "selectbydensity":
    cutidx = np.where(global_pred_densities >= 0.5)
    draw_minvalues = minvalues[cutidx]
    draw_densities = global_pred_densities[cutidx]
    with open(args.density_distance_pkl, "wb") as f:
        pickle.dump((global_pred_densities, minvalues), f)
    print("dumped.")
    plt.clf()
@@ -283,15 +334,29 @@ if MODE == "selectbydensity":
    if len(draw_densities) > 10000:
        # plot a random subsample so the scatter stays readable
        samples_idx = random.sample(range(len(draw_minvalues)), 10000)
        ax.plot(
            draw_densities[samples_idx],
            draw_minvalues[samples_idx],
            color="tab:blue",
            marker="*",
            linestyle="None",
            markersize=1,
        )
    else:
        ax.plot(
            draw_densities,
            draw_minvalues,
            color="tab:blue",
            marker="*",
            linestyle="None",
            markersize=1,
        )
    plt.savefig(args.density_lindistance_jpg)

    global_pred_labels_new[Tidx] = l_in_gt_new
    global_pred_labels[selectidx] = global_pred_labels[selectidx] + len(
        l_in_unique
    )
    global_pred_labels_new[selectedidx] = global_pred_labels
    global_pred_labels = global_pred_labels_new
@@ -332,7 +397,9 @@ if MODE == "recluster":
    global_pred_labels_new[Tidx] = l_in_gt_new
    print(len(global_pred_labels))
    print(len(selectedidx[0]))
    global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(
        l_in_unique
    )
    global_pred_labels = global_pred_labels_new
    global_masks = masks
    print("mask0", len(np.where(global_masks == 0)[0]))
@@ -348,23 +415,29 @@ if MODE == "donothing":
print("##################### L_in ########################")
print(linsize)
if len(global_pred_labels) >= linsize:
    evaluation(
        global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics
    )
else:
    print("No samples in L_in!")

print("##################### U_in ########################")
uinidx = np.where(global_pred_labels[linsize : linsize + uinsize] != -1)[0]
uinidx = uinidx + linsize
print(len(uinidx))
if len(uinidx):
    evaluation(
        global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics
    )
else:
    print("No samples in U_in!")

print("##################### U_out ########################")
uoutidx = np.where(global_pred_labels[linsize + uinsize :] != -1)[0]
uoutidx = uoutidx + linsize + uinsize
print(len(uoutidx))
if len(uoutidx):
    evaluation(
        global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics
    )
else:
    print("No samples in U_out!")
print("##################### U ########################") print("##################### U ########################")
...@@ -390,9 +463,18 @@ print(len(nsidx)) ...@@ -390,9 +463,18 @@ print(len(nsidx))
if len(nsidx) != 0: if len(nsidx) != 0:
evaluation(global_pred_labels[nsidx], global_gt_labels[nsidx], args.metrics) evaluation(global_pred_labels[nsidx], global_gt_labels[nsidx], args.metrics)
with open(args.output_filename, 'wb') as f: with open(args.output_filename, "wb") as f:
print(orifeatures.shape) print(orifeatures.shape)
print(global_pred_labels.shape) print(global_pred_labels.shape)
print(global_gt_labels.shape) print(global_gt_labels.shape)
print(global_masks.shape) print(global_masks.shape)
pickle.dump([path2idx, orifeatures, global_pred_labels, global_gt_labels, global_masks], f) pickle.dump(
[
path2idx,
orifeatures,
global_pred_labels,
global_gt_labels,
global_masks,
],
f,
)
import argparse, os, pickle, time
import random
import sys

import dgl
import numpy as np
import torch
import torch.optim as optim

sys.path.append("..")
from dataset import LanderDataset
from models import LANDER
###########
# ArgParser
parser = argparse.ArgumentParser()
# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--levels", type=str, default="1")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--model_filename", type=str, default="lander.pth")
# KNN
parser.add_argument("--knn_k", type=str, default="10")
parser.add_argument("--num_workers", type=int, default=0)
# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=1)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
# Training
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=1024)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=1e-5)
args = parser.parse_args()
print(args)
@@ -49,9 +50,9 @@ print(args)
###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def setup_seed(seed):
@@ -66,7 +67,7 @@ def setup_seed(seed):
##################
# Data Preparation
with open(args.data_path, "rb") as f:
    path2idx, features, labels, _, masks = pickle.load(f)
# lidx = np.where(masks==0)
# features = features[lidx]
@@ -75,8 +76,8 @@ with open(args.data_path, 'rb') as f:
print("labels.shape:", labels.shape)

k_list = [int(k) for k in args.knn_k.split(",")]
lvl_list = [int(l) for l in args.levels.split(",")]
gs = []
nbrs = []
ks = []
@@ -84,8 +85,13 @@ datasets = []
for k, l in zip(k_list, lvl_list):
    print("k:", k)
    print("levels:", l)
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=k,
        levels=l,
        faiss_gpu=args.faiss_gpu,
    )
    gs += [g for g in dataset.gs]
    ks += [k for g in dataset.gs]
    nbrs += [nbr for nbr in dataset.nbrs]
@@ -101,24 +107,28 @@ for k, l in zip(k_list, lvl_list):
# nbrs += [nbr for nbr in dataset.nbrs]

with open("./dataset.pkl", "wb") as f:
    pickle.dump(datasets, f)
print("Dataset Prepared.")
def set_train_sampler_loader(g, k):
    fanouts = [k - 1 for i in range(args.num_conv + 1)]
    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
    # fix the number of edges
    train_dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )
    return train_dataloader


train_loaders = []
for gidx, g in enumerate(gs):
    train_dataloader = set_train_sampler_loader(gs[gidx], ks[gidx])
@@ -126,31 +136,40 @@ for gidx, g in enumerate(gs):
##################
# Model Definition
feature_dim = gs[0].ndata["features"].shape[1]
print("feature dimension:", feature_dim)
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()
#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# keep num_batch_per_loader the same for every sub_dataloader
num_batch_per_loader = len(train_loaders[0])
train_loaders = [iter(train_loader) for train_loader in train_loaders]
num_loaders = len(train_loaders)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5
)
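# --- Illustrative note, not part of this commit ---
# scheduler.step() is called once per mini-batch in the loop below, so T_max
# equals the total number of optimizer updates (epochs * batches per loader *
# number of loaders) and the learning rate follows a single cosine curve from
# args.lr down to eta_min=1e-5 over the whole run.
# ---------------------------------------------------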
print("Start Training.")

###############
# Training Loop
@@ -163,7 +182,9 @@ for epoch in range(args.epochs):
            try:
                minibatch = next(train_loaders[loader_id])
            except StopIteration:
                # this loader ran out before the longest one; rebuild it and keep going
                train_loaders[loader_id] = iter(
                    set_train_sampler_loader(gs[loader_id], ks[loader_id])
                )
                minibatch = next(train_loaders[loader_id])
            input_nodes, sub_g, bipartites = minibatch
            sub_g = sub_g.to(device)
@@ -171,20 +192,38 @@ for epoch in range(args.epochs):
            # get the feature for the input_nodes
            opt.zero_grad()
            output_bipartite = model(bipartites)
            loss, loss_den_val, loss_conn_val = model.compute_loss(
                output_bipartite
            )
            loss_den_val_total.append(loss_den_val)
            loss_conn_val_total.append(loss_conn_val)
            loss_val_total.append(loss.item())
            loss.backward()
            opt.step()
            if (batch + 1) % 10 == 0:
                print(
                    "epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
                    % (
                        epoch,
                        batch,
                        num_batch_per_loader,
                        loader_id,
                        num_loaders,
                        loss.item(),
                        loss_den_val,
                        loss_conn_val,
                    )
                )
            scheduler.step()
    print(
        "epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
        % (
            epoch,
            np.array(loss_val_total).mean(),
            np.array(loss_den_val_total).mean(),
            np.array(loss_conn_val_total).mean(),
        )
    )
    torch.save(model.state_dict(), args.model_filename)
torch.save(model.state_dict(), args.model_filename)
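The loop above draws one mini-batch from each graph's DataLoader in turn and rebuilds any iterator that is exhausted before the longest one. A minimal sketch of that round-robin pattern with plain Python iterators (illustrative only; names and lengths are assumed, not taken from this commit):

def round_robin(make_iters, num_rounds):
    # One iterator per source; each exhausted iterator is simply re-created.
    iters = [make() for make in make_iters]
    for _ in range(num_rounds):
        for i, make in enumerate(make_iters):
            try:
                item = next(iters[i])
            except StopIteration:
                iters[i] = make()  # restart, mirroring the except branch above
                item = next(iters[i])
            yield i, item

# Example with two "loaders" of different lengths, driven for 4 rounds:
print(list(round_robin([lambda: iter([1, 2]), lambda: iter("abc")], 4)))
# [(0, 1), (1, 'a'), (0, 2), (1, 'b'), (0, 1), (1, 'c'), (0, 2), (1, 'a')]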
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from torch.nn import init


class GraphConvLayer(nn.Module):
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .focal_loss import FocalLoss
from .graphconv import GraphConv
...
@@ -3,6 +3,8 @@ import os
import pickle
import time

import dgl
import numpy as np
import torch
import torch.optim as optim
@@ -10,8 +12,6 @@ from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating

###########
# ArgParser
parser = argparse.ArgumentParser()
...
@@ -3,6 +3,8 @@ import os
import pickle
import time

import dgl
import numpy as np
import torch
import torch.optim as optim
@@ -10,8 +12,6 @@ from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating

###########
# ArgParser
parser = argparse.ArgumentParser()
...
@@ -3,14 +3,14 @@ import os
import pickle
import time

import dgl
import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()
@@ -50,6 +50,7 @@ if torch.cuda.is_available():
else:
    device = torch.device("cpu")


##################
# Data Preparation
def prepare_dataset_graphs(data_path, k_list, lvl_list):
...
@@ -3,14 +3,14 @@ import os
import pickle
import time

import dgl
import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()
...
""" """
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
""" """
import dgl
import numpy as np import numpy as np
import torch import torch
from sklearn import mixture from sklearn import mixture
import dgl
from .density import density_to_peaks, density_to_peaks_vectorize from .density import density_to_peaks, density_to_peaks_vectorize
__all__ = [ __all__ = [
......
@@ -6,7 +6,7 @@ import inspect
import numpy as np
from clustering_benchmark import ClusteringBenchmark
from utils import metrics, TextColors, Timer


def _read_meta(fn):
...
@@ -101,7 +101,6 @@ def faiss_search_knn(
    sort=True,
    verbose=False,
):
    dists, nbrs = faiss_search_approx_knn(
        query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose
    )
...
@@ -70,7 +70,6 @@ def svc_classify(x, y, search):
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
    accuracies = []
    for train_index, test_index in kf.split(x, y):
        x_train, x_test = x[train_index], x[test_index]
        y_train, y_test = y[train_index], y[test_index]
...
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GINConv, NNConv, Set2Set
from dgl.nn.pytorch.glob import SumPooling
from torch.nn import BatchNorm1d, GRU, Linear, ModuleList, ReLU, Sequential
from utils import global_global_loss_, local_global_loss_

""" Feedforward neural network"""
@@ -102,7 +102,6 @@ class GINEncoder(nn.Module):
        self.pool = SumPooling()

    def forward(self, graph, feat):
        xs = []
        x = feat
        for i in range(self.n_layer):
@@ -163,7 +162,6 @@ class InfoGraph(nn.Module):
        return global_emb

    def forward(self, graph, feat, graph_id):
        global_emb, local_emb = self.encoder(graph, feat)
        global_h = self.global_d(global_emb)  # global hidden representation
@@ -221,7 +219,6 @@ class NNConvEncoder(nn.Module):
        self.set2set = Set2Set(hid_dim, n_iters=3, n_layers=1)

    def forward(self, graph, nfeat, efeat):
        out = F.relu(self.lin0(nfeat))
        h = out.unsqueeze(0)
@@ -279,7 +276,6 @@ class InfoGraphS(nn.Module):
        self.unsup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)

    def forward(self, graph, nfeat, efeat):
        sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
        sup_global_pred = self.fc2(F.relu(self.fc1(sup_global_emb)))
@@ -288,7 +284,6 @@ class InfoGraphS(nn.Module):
        return sup_global_pred

    def unsup_forward(self, graph, nfeat, efeat, graph_id):
        sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
        unsup_global_emb, unsup_local_emb = self.unsup_encoder(
            graph, nfeat, efeat
...
import argparse

import dgl
import numpy as np
import torch as th
import torch.nn.functional as F
from dgl.data import QM9EdgeDataset
from dgl.data.utils import Subset
from dgl.dataloading import GraphDataLoader
from model import InfoGraphS


def argument():
@@ -160,7 +160,6 @@ def evaluate(model, loader, num, device):
if __name__ == "__main__":
    # Step 1: Prepare graph data ===================================== #
    args = argument()
    label_keys = [args.target]
...
import argparse

import dgl
import torch as th
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
from evaluate_embedding import evaluate_embedding
from model import InfoGraph


def argument():
@@ -75,7 +75,6 @@ def collate(samples):
if __name__ == "__main__":
    # Step 1: Prepare graph data ===================================== #
    args = argument()
    print(args)
@@ -131,7 +130,6 @@ if __name__ == "__main__":
        model.train()
        for graph, label in dataloader:
            graph = graph.to(args.device)
            feat = graph.ndata["attr"]
            graph_id = graph.ndata["graph_id"]
@@ -147,7 +145,6 @@ if __name__ == "__main__":
        print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))
        if epoch % args.log_interval == 0:
            # evaluate embeddings
            model.eval()
            emb = model.get_embedding(wholegraph, wholefeat).cpu()
...
@@ -41,7 +41,6 @@ def get_negative_expectation(q_samples, average=True):
def local_global_loss_(l_enc, g_enc, graph_id):
    num_graphs = g_enc.shape[0]
    num_nodes = l_enc.shape[0]
@@ -51,7 +50,6 @@ def local_global_loss_(l_enc, g_enc, graph_id):
    neg_mask = th.ones((num_nodes, num_graphs)).to(device)
    for nodeidx, graphidx in enumerate(graph_id):
        pos_mask[nodeidx][graphidx] = 1.0
        neg_mask[nodeidx][graphidx] = 0.0
@@ -66,7 +64,6 @@ def local_global_loss_(l_enc, g_enc, graph_id):
def global_global_loss_(sup_enc, unsup_enc):
    num_graphs = sup_enc.shape[0]
    device = sup_enc.device
...