import argparse
import pickle

import dgl
import numpy as np
import torch

from dataset import LanderDataset
from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating
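
# Example invocation (paths and values are illustrative):
#   python test.py --data_path data/features.pkl --model_filename lander.pth \
#       --knn_k 10 --levels 2 --faiss_gpu --early_stop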

###########
# ArgParser
parser = argparse.ArgumentParser()

# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--model_filename", type=str, default="lander.pth")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--early_stop", action="store_true")

# Hyperparameters
parser.add_argument("--knn_k", type=int, default=10)
parser.add_argument("--levels", type=int, default=1)
parser.add_argument("--tau", type=float, default=0.5)
parser.add_argument("--threshold", type=str, default="prob")
parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")

# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
parser.add_argument("--use_gt", action="store_true")

args = parser.parse_args()

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
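# The pickle is assumed to hold a (features, labels) pair of NumPy arrays.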
with open(args.data_path, "rb") as f:
    features, labels = pickle.load(f)
global_features = features.copy()
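# Build the level-0 graph: a k-NN graph (k = args.knn_k) over the raw features.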
dataset = LanderDataset(
    features=features,
    labels=labels,
    k=args.knn_k,
    levels=1,
    faiss_gpu=args.faiss_gpu,
)
g = dataset.gs[0].to(device)
global_labels = labels.copy()
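# `ids` maps each node of the current level back to its original level-0 index.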
ids = np.arange(g.num_nodes())
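# Edges predicted at every level are accumulated into one global edge list.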
global_edges = ([], [])
global_edges_len = len(global_edges[0])
global_num_nodes = g.num_nodes()

##################
# Model Definition
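# With --use_gt, decoding uses ground-truth connectivity, so no model is built or loaded.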
if not args.use_gt:
    feature_dim = g.ndata["features"].shape[1]
    model = LANDER(
        feature_dim=feature_dim,
        nhid=args.hidden,
        num_conv=args.num_conv,
        dropout=args.dropout,
        use_GAT=args.gat,
        K=args.gat_k,
        balance=args.balance,
        use_cluster_feat=args.use_cluster_feat,
        use_focal_loss=args.use_focal_loss,
    )
    # Load the trained weights; map to CPU first in case they were saved on GPU.
    model.load_state_dict(torch.load(args.model_filename, map_location="cpu"))
    model = model.to(device)
    model.eval()

# The number of edges added per level is the indicator for early stopping.
num_edges_add_last_level = np.inf
##################################
# Predict connectivity and density
for level in range(args.levels):
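    # Unless ground-truth connectivity is used, run one LANDER forward pass to
    # predict edge connectivity and node density for this level.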
    if not args.use_gt:
        with torch.no_grad():
            g = model(g)
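    # Decode the predictions: pick density peaks and connect the remaining
    # nodes, updating the accumulated global prediction.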
    (
        new_pred_labels,
        peaks,
        global_edges,
        global_pred_labels,
        global_peaks,
    ) = decode(
        g,
        args.tau,
        args.threshold,
        args.use_gt,
        ids,
        global_edges,
        global_num_nodes,
    )
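    # Only the peak nodes survive into the next level; remap their original ids
    # and count how many global edges this level contributed.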
    ids = ids[peaks]
    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
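    # Stop at the final level, or early (with --early_stop) when the number of
    # newly added edges stalls (see stop_iterating in utils).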
    if stop_iterating(
        level,
        args.levels,
        args.early_stop,
        num_edges_add_this_level,
        num_edges_add_last_level,
        args.knn_k,
    ):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level

    # Build the features and labels for the next level from the current peaks.
    features, labels, cluster_features = build_next_level(
        features,
        labels,
        peaks,
        global_features,
        global_pred_labels,
        global_peaks,
    )
    # After the first level the number of nodes drops sharply, so CPU FAISS is
    # faster than GPU FAISS here.
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=args.knn_k,
        levels=1,
        faiss_gpu=False,
        cluster_features=cluster_features,
    )
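    # If no graph could be built for the next level, the hierarchy ends here.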
    if len(dataset.gs) == 0:
        break
    g = dataset.gs[0].to(device)
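
############
# Evaluation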
evaluation(global_pred_labels, global_labels, args.metrics)