"""Run hierarchical clustering inference and evaluate the predicted labels.

The script loads a pickled ``(features, labels)`` pair, builds a
``LanderDataset`` graph with ``--knn_k`` neighbours, and then for up to
``--levels`` iterations:

1. runs the trained ``LANDER`` model on the current graph (skipped when
   ``--use_gt`` is given),
2. calls ``decode`` to obtain the peaks and the accumulated global
   edges / predicted labels,
3. calls ``build_next_level`` to produce the inputs of the next, coarser
   graph.

Finally ``evaluation`` is called with the global predictions, the original
labels and the ``--metrics`` string.
"""
import argparse
import os
import pickle
import time

import numpy as np
import torch
import torch.optim as optim
from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating

import dgl

###########
# ArgParser
parser = argparse.ArgumentParser()

# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--model_filename", type=str, default="lander.pth")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--early_stop", action="store_true")

# HyperParam
parser.add_argument("--knn_k", type=int, default=10)
parser.add_argument("--levels", type=int, default=1)
parser.add_argument("--tau", type=float, default=0.5)
parser.add_argument("--threshold", type=str, default="prob")
parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")

# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
parser.add_argument("--use_gt", action="store_true")

args = parser.parse_args()

# With fewer than one level the prediction loop never runs, so
# ``global_pred_labels`` would be undefined when ``evaluation`` is called.
# Reject the value up front instead of failing with a NameError after the
# data has already been loaded.
if args.levels < 1:
    parser.error("--levels must be at least 1")

###########################
# Environment Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##################
# Data Preparation
# NOTE(security): pickle.load executes arbitrary code embedded in the file.
# Only point --data_path at files from a trusted source.
with open(args.data_path, "rb") as f:
    features, labels = pickle.load(f)

# Keep untouched copies: ``features`` / ``labels`` are rebound at every level,
# while the global copies are passed to build_next_level / evaluation.
global_features = features.copy()
global_labels = labels.copy()

dataset = LanderDataset(
    features=features,
    labels=labels,
    k=args.knn_k,
    levels=1,
    faiss_gpu=args.faiss_gpu,
)
g = dataset.gs[0].to(device)

global_num_nodes = g.number_of_nodes()
ids = np.arange(global_num_nodes)
global_edges = ([], [])
global_edges_len = len(global_edges[0])

##################
# Model Definition
# The model is only needed when predictions are not taken from ground truth.
if not args.use_gt:
    feature_dim = g.ndata["features"].shape[1]
    model = LANDER(
        feature_dim=feature_dim,
        nhid=args.hidden,
        num_conv=args.num_conv,
        dropout=args.dropout,
        use_GAT=args.gat,
        K=args.gat_k,
        balance=args.balance,
        use_cluster_feat=args.use_cluster_feat,
        use_focal_loss=args.use_focal_loss,
    )
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # host (and vice versa) instead of raising during deserialization.
    model.load_state_dict(
        torch.load(args.model_filename, map_location=device)
    )
    model = model.to(device)
    model.eval()

# number of edges added is the indicator for early stopping
# (np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0)
num_edges_add_last_level = np.inf

##################################
# Predict connectivity and density
for level in range(args.levels):
    if not args.use_gt:
        with torch.no_grad():
            g = model(g)

    (
        new_pred_labels,
        peaks,
        global_edges,
        global_pred_labels,
        global_peaks,
    ) = decode(
        g,
        args.tau,
        args.threshold,
        args.use_gt,
        ids,
        global_edges,
        global_num_nodes,
    )

    # Keep only the ids selected as peaks for the next level.
    ids = ids[peaks]

    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
    if stop_iterating(
        level,
        args.levels,
        args.early_stop,
        num_edges_add_this_level,
        num_edges_add_last_level,
        args.knn_k,
    ):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level

    # build new dataset
    features, labels, cluster_features = build_next_level(
        features,
        labels,
        peaks,
        global_features,
        global_pred_labels,
        global_peaks,
    )
    # After the first level, the number of nodes reduces a lot.
    # Using cpu faiss is faster.
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=args.knn_k,
        levels=1,
        faiss_gpu=False,
        cluster_features=cluster_features,
    )
    # No graph could be built for the next level: nothing left to process.
    if len(dataset.gs) == 0:
        break
    g = dataset.gs[0].to(device)

evaluation(global_pred_labels, global_labels, args.metrics)