import argparse, time, os, pickle
import random
import sys
sys.path.append("..")
from utils.deduce import get_edge_dist
import numpy as np
import shutil
import dgl
import torch
import torch.optim as optim

from models import LANDER
from dataset import LanderDataset
from utils import evaluation, decode, build_next_level, stop_iterating
from matplotlib import pyplot as plt
import seaborn

STATISTIC = False

###########
# ArgParser
parser = argparse.ArgumentParser()

# Dataset
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_filename', type=str, default='lander.pth')
parser.add_argument('--faiss_gpu', action='store_true')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--output_filename', type=str, default='data/features.pkl')

# HyperParam
parser.add_argument('--knn_k', type=int, default=10)
parser.add_argument('--levels', type=int, default=1)
parser.add_argument('--tau', type=float, default=0.5)
parser.add_argument('--threshold', type=str, default='prob')
parser.add_argument('--metrics', type=str, default='pairwise,bcubed,nmi')
parser.add_argument('--early_stop', action='store_true')

# Model
parser.add_argument('--hidden', type=int, default=512)
parser.add_argument('--num_conv', type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--gat', action='store_true')
parser.add_argument('--gat_k', type=int, default=1)
parser.add_argument('--balance', action='store_true')
parser.add_argument('--use_cluster_feat', action='store_true')
parser.add_argument('--use_focal_loss', action='store_true')
parser.add_argument('--use_gt', action='store_true')

# Subgraph
parser.add_argument('--batch_size', type=int, default=4096)
parser.add_argument('--mode', type=str, default="1head")
parser.add_argument('--midpoint', type=str, default="false")
parser.add_argument('--linsize', type=int, default=29011)
parser.add_argument('--uinsize', type=int, default=18403)
parser.add_argument('--inclasses', type=int, default=948)
parser.add_argument('--thresh', type=float, default=1.0)
parser.add_argument('--draw', type=str, default='false')
parser.add_argument('--density_distance_pkl', type=str, default="density_distance.pkl")
parser.add_argument('--density_lindistance_jpg', type=str, default="density_lindistance.jpg")

args = parser.parse_args()
print(args)

MODE = args.mode
linsize = args.linsize
uinsize = args.uinsize
inclasses = args.inclasses

if args.draw == 'false':
    args.draw = False
elif args.draw == 'true':
    args.draw = True

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

##################
# Data Preparation
with open(args.data_path, 'rb') as f:
    loaded_data = pickle.load(f)
path2idx, features, pred_labels, labels, masks = loaded_data
idx2path = {v: k for k, v in path2idx.items()}
gtlabels = labels
orifeatures = features
orilabels = gtlabels

if MODE == "selectbydensity":
    lastusim = np.where(masks == 1)
    masks[lastusim] = 2
    selectedidx = np.where(masks != 0)
    features = features[selectedidx]
    labels = gtlabels[selectedidx]
    selectmasks = masks[selectedidx]
    print("filtered features:", len(features))
    print("mask0:", len(np.where(masks == 0)[0]))
    print("mask1:", len(np.where(masks == 1)[0]))
    print("mask2:", len(np.where(masks == 2)[0]))
elif MODE == "recluster":
    selectedidx = np.where(masks == 1)
    features = features[selectedidx]
    labels = gtlabels[selectedidx]
    labelspred = pred_labels[selectedidx]
    selectmasks = masks[selectedidx]
    gtlabels = gtlabels[selectedidx]
    print("filtered features:", len(features))
else:
    selectedidx = np.where(masks != 0)
    features = features[selectedidx]
    labels = gtlabels[selectedidx]
    labelspred = pred_labels[selectedidx]
    selectmasks = masks[selectedidx]
    gtlabels = gtlabels[selectedidx]
    print("filtered features:", len(features))
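# Mask convention assumed throughout this script (inferred from the branches
# above; the input pickle itself does not document it): 0 = labelled L_in
# samples, 1 = unlabelled samples selected in the previous round,
# 2 = unlabelled samples that were not selected. A minimal sanity check,
# using only names defined above:
assert len(features) == len(labels) == len(selectmasks), \
    "filtered arrays must stay aligned"
assert set(np.unique(masks)).issubset({0, 1, 2}), \
    "unexpected mask value in input pickle"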
global_features = features.copy()  # global features
dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
                        levels=1, faiss_gpu=False)
g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
global_labels = labels.copy()
ids = np.arange(g.number_of_nodes())
global_edges = ([], [])
global_peaks = np.array([], dtype=np.int64)  # np.long was removed in NumPy 1.24
global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes()

global_densities = g.ndata['density'][:linsize]
global_densities = np.sort(global_densities)
xs = np.arange(len(global_densities))

# Fix the number of sampled neighbours per node.
fanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
test_loader = dgl.dataloading.NodeDataLoader(
    g, torch.arange(g.number_of_nodes()), sampler,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers
)

##################
# Model Definition
if not args.use_gt:
    feature_dim = g.ndata['features'].shape[1]
    model = LANDER(feature_dim=feature_dim, nhid=args.hidden,
                   num_conv=args.num_conv, dropout=args.dropout,
                   use_GAT=args.gat, K=args.gat_k,
                   balance=args.balance,
                   use_cluster_feat=args.use_cluster_feat,
                   use_focal_loss=args.use_focal_loss)
    # map_location keeps the load working on CPU-only machines
    model.load_state_dict(torch.load(args.model_filename, map_location=device))
    model = model.to(device)
    model.eval()

# The number of edges added per level is the indicator for early stopping.
num_edges_add_last_level = np.inf

##################################
# Predict connectivity and density
for level in range(args.levels):
    print("level:", level)
    if not args.use_gt:
        total_batches = len(test_loader)
        for batch, minibatch in enumerate(test_loader):
            input_nodes, sub_g, bipartites = minibatch
            sub_g = sub_g.to(device)
            bipartites = [b.to(device) for b in bipartites]
            with torch.no_grad():
                output_bipartite = model(bipartites)
            global_nid = output_bipartite.dstdata[dgl.NID]
            global_eid = output_bipartite.edata['global_eid']
            g.ndata['pred_den'][global_nid] = output_bipartite.dstdata['pred_den'].to('cpu')
            g.edata['prob_conn'][global_eid] = output_bipartite.edata['prob_conn'].to('cpu')
            torch.cuda.empty_cache()
            if (batch + 1) % 10 == 0:
                print('Batch %d / %d for inference' % (batch, total_batches))
    new_pred_labels, peaks, \
        global_edges, global_pred_labels, global_peaks = decode(g, args.tau, args.threshold,
                                                                args.use_gt, ids, global_edges,
                                                                global_num_nodes, global_peaks)
    if level == 0:
        global_pred_densities = g.ndata['pred_den']
        global_densities = g.ndata['density']
        g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
    ids = ids[peaks]
    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
    if stop_iterating(level, args.levels, args.early_stop, num_edges_add_this_level,
                      num_edges_add_last_level, args.knn_k):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level

    # Build the dataset for the next level. After the first level the number
    # of nodes drops a lot, so CPU faiss is faster.
    features, labels, cluster_features = build_next_level(features, labels, peaks,
                                                          global_features,
                                                          global_pred_labels,
                                                          global_peaks)
    dataset = LanderDataset(features=features, labels=labels, k=args.knn_k,
                            levels=1, faiss_gpu=False,
                            cluster_features=cluster_features)
    g = dataset.gs[0]
    g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
    g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
    test_loader = dgl.dataloading.NodeDataLoader(
        g, torch.arange(g.number_of_nodes()), sampler,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers
    )
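# Optional diagnostic (a sketch, not part of the original pipeline): the gap
# between predicted and ground-truth density on the first level is what the
# "selectbydensity" branch below thresholds on, so printing its distribution
# here can help choose --thresh. Guarded by the existing STATISTIC flag.
if STATISTIC:
    _pred = np.asarray(global_pred_densities, dtype=float)
    _gt = np.asarray(global_densities, dtype=float)
    _gap = np.abs(_pred - _gt)
    print("density |pred - gt|: mean %.4f, median %.4f"
          % (np.mean(_gap), np.median(_gap)))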
if MODE == "selectbydensity":
    thresh = args.thresh
    global_pred_densities = np.array(global_pred_densities).astype(float)
    global_densities = np.array(global_densities).astype(float)
    distance = np.abs(global_pred_densities - global_densities)
    print("densities shape", global_pred_densities.shape)
    print(global_pred_densities.max(), global_pred_densities.min())
    selectidx = np.where(global_pred_densities > thresh)[0]
    selected_pred_densities = global_pred_densities[selectidx]
    selected_densities = global_densities[selectidx]
    selected_distance = np.abs(selected_pred_densities - selected_densities)
    print(np.mean(selected_distance))
    print("number of selected samples:", len(selectidx))
    notselectidx = np.where(global_pred_densities <= thresh)
    print("not selected:", len(notselectidx[0]))
    global_pred_labels[notselectidx] = -1

    global_pred_labels_new = np.zeros_like(orilabels)
    global_pred_labels_new[:] = -1
    Tidx = np.where(masks != 2)
    print("T:", len(Tidx[0]))
    l_in_gt = orilabels[Tidx]
    l_in_features = orifeatures[Tidx]
    l_in_gt_new = np.zeros_like(l_in_gt)
    l_in_unique = np.unique(l_in_gt)
    for i in range(len(l_in_unique)):
        l_in = l_in_unique[i]
        l_in_idx = np.where(l_in_gt == l_in)
        l_in_gt_new[l_in_idx] = i
    print("len(l_in_unique)", len(l_in_unique))

    if args.draw:
        prototypes = np.zeros((len(l_in_unique), features.shape[1]))
        for i in range(len(l_in_unique)):
            idx = np.where(l_in_gt_new == i)
            prototypes[i] = np.mean(l_in_features[idx], axis=0)
        similarity_matrix = torch.mm(torch.from_numpy(global_features.astype(np.float32)),
                                     torch.from_numpy(prototypes.astype(np.float32)).t())
        similarity_matrix = (1 - similarity_matrix) / 2
        minvalues, selected_pred_labels = torch.min(similarity_matrix, 1)
        minvalues = minvalues.numpy()
        # far-close ratio
        closeidx = np.where(minvalues < 0.15)
        faridx = np.where(minvalues >= 0.15)
        print("far:", len(faridx[0]))
        print("close:", len(closeidx[0]))
        cutidx = np.where(global_pred_densities >= 0.5)
        draw_minvalues = minvalues[cutidx]
        draw_densities = global_pred_densities[cutidx]
        with open(args.density_distance_pkl, 'wb') as f:
            pickle.dump((global_pred_densities, minvalues), f)
        print("dumped.")
        plt.clf()
        fig, ax = plt.subplots()
        # Subsample for plotting when there are too many points. The original
        # code indexed with the `random` module itself; use the sampled
        # indices (or the full arrays) instead.
        if len(draw_densities) > 10000:
            samples_idx = random.sample(range(len(draw_minvalues)), 10000)
            ax.plot(draw_densities[samples_idx], draw_minvalues[samples_idx],
                    color='tab:blue', marker='*', linestyle="None", markersize=1)
        else:
            ax.plot(draw_densities, draw_minvalues,
                    color='tab:blue', marker='*', linestyle="None", markersize=1)
        plt.savefig(args.density_lindistance_jpg)

    global_pred_labels_new[Tidx] = l_in_gt_new
    global_pred_labels[selectidx] = global_pred_labels[selectidx] + len(l_in_unique)
    global_pred_labels_new[selectedidx] = global_pred_labels
    global_pred_labels = global_pred_labels_new
    linunique = np.unique(global_pred_labels[Tidx])
    uunique = np.unique(global_pred_labels[selectedidx])
    allunique = np.unique(global_pred_labels)
    print("labels")
    print(len(linunique), len(uunique), len(allunique))
    global_masks = np.zeros_like(masks)
    global_masks[:] = 1
    global_masks[np.array(selectedidx[0])[notselectidx]] = 2
    Tidx = np.where(masks != 2)
    global_masks[Tidx] = 0
    print("mask0", len(np.where(global_masks == 0)[0]))
    print("mask1", len(np.where(global_masks == 1)[0]))
    print("mask2", len(np.where(global_masks == 2)[0]))
    print("all", len(masks), len(orilabels), len(orifeatures))
    global_gt_labels = orilabels
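# Label bookkeeping used by both post-processing modes (illustrative summary):
# L_in ground-truth classes are re-indexed to 0..len(l_in_unique)-1, and
# cluster ids predicted on the unlabelled subset are shifted by
# len(l_in_unique) so the two id spaces cannot collide. For example, with 3
# L_in classes, predicted cluster 0 becomes global label 3, cluster 1 becomes
# 4, and so on; -1 always means "not assigned to any cluster".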
if MODE == "recluster":
    global_pred_labels_new = np.zeros_like(orilabels)
    global_pred_labels_new[:] = -1
    Tidx = np.where(masks == 0)
    print("T:", len(Tidx[0]))
    l_in_gt = orilabels[Tidx]
    l_in_features = orifeatures[Tidx]
    l_in_gt_new = np.zeros_like(l_in_gt)
    l_in_unique = np.unique(l_in_gt)
    for i in range(len(l_in_unique)):
        l_in = l_in_unique[i]
        l_in_idx = np.where(l_in_gt == l_in)
        l_in_gt_new[l_in_idx] = i
    print("len(l_in_unique)", len(l_in_unique))
    global_pred_labels_new[Tidx] = l_in_gt_new
    print(len(global_pred_labels))
    print(len(selectedidx[0]))
    global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(l_in_unique)
    global_pred_labels = global_pred_labels_new
    global_masks = masks
    print("mask0", len(np.where(global_masks == 0)[0]))
    print("mask1", len(np.where(global_masks == 1)[0]))
    print("mask2", len(np.where(global_masks == 2)[0]))
    print("all", len(masks), len(orilabels), len(orifeatures))
    global_gt_labels = orilabels

if MODE == "donothing":
    global_masks = masks
    # Assumed fix: the evaluation below reads global_gt_labels, which no other
    # statement defines in this mode; the filtered gtlabels is the array
    # aligned with global_pred_labels here.
    global_gt_labels = gtlabels

print("##################### L_in ########################")
print(linsize)
if len(global_pred_labels) >= linsize:
    evaluation(global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics)
else:
    print("No samples in L_in!")

print("##################### U_in ########################")
uinidx = np.where(global_pred_labels[linsize:linsize + uinsize] != -1)[0]
uinidx = uinidx + linsize
print(len(uinidx))
if len(uinidx):
    evaluation(global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics)
else:
    print("No samples in U_in!")

print("##################### U_out ########################")
uoutidx = np.where(global_pred_labels[linsize + uinsize:] != -1)[0]
uoutidx = uoutidx + linsize + uinsize
print(len(uoutidx))
if len(uoutidx):
    evaluation(global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics)
else:
    print("No samples in U_out!")

print("##################### U ########################")
uidx = np.where(global_pred_labels[linsize:] != -1)[0]
uidx = uidx + linsize
print(len(uidx))
if len(uidx):
    evaluation(global_pred_labels[uidx], global_gt_labels[uidx], args.metrics)
else:
    print("No samples in U!")

print("##################### L+U ########################")
luidx = np.where(global_pred_labels != -1)[0]
print(len(luidx))
evaluation(global_pred_labels[luidx], global_gt_labels[luidx], args.metrics)

print("##################### new selected samples ########################")
sidx = np.where(global_masks == 1)[0]
print(len(sidx))
if len(sidx) != 0:
    evaluation(global_pred_labels[sidx], global_gt_labels[sidx], args.metrics)

print("##################### not selected samples ########################")
nsidx = np.where(global_masks == 2)[0]
print(len(nsidx))
if len(nsidx) != 0:
    evaluation(global_pred_labels[nsidx], global_gt_labels[nsidx], args.metrics)

with open(args.output_filename, 'wb') as f:
    print(orifeatures.shape)
    print(global_pred_labels.shape)
    print(global_gt_labels.shape)
    print(global_masks.shape)
    pickle.dump([path2idx, orifeatures, global_pred_labels, global_gt_labels, global_masks], f)
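# Usage sketch (the script name, paths, and flag values below are
# illustrative, not prescriptive):
#
#   python test_subg.py --data_path data/features_in.pkl \
#       --model_filename lander.pth --mode selectbydensity \
#       --knn_k 10 --tau 0.5 --thresh 1.0 \
#       --output_filename data/features.pkl
#
# The dumped pickle mirrors the input layout and can be reloaded with:
#
#   with open('data/features.pkl', 'rb') as f:
#       path2idx, feats, pred, gt, m = pickle.load(f)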