import argparse
import pickle

import dgl

import numpy as np

import torch
from dataset import LanderDataset

from models import LANDER
from utils import build_next_level, decode, evaluation, stop_iterating

###########
# ArgParser
parser = argparse.ArgumentParser()

# Dataset
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--model_filename", type=str, default="lander.pth")
parser.add_argument("--faiss_gpu", action="store_true")
parser.add_argument("--num_workers", type=int, default=0)

# HyperParam
parser.add_argument("--knn_k", type=int, default=10)
parser.add_argument("--levels", type=int, default=1)
parser.add_argument("--tau", type=float, default=0.5)
parser.add_argument("--threshold", type=str, default="prob")
parser.add_argument("--metrics", type=str, default="pairwise,bcubed,nmi")
parser.add_argument("--early_stop", action="store_true")

# Model
parser.add_argument("--hidden", type=int, default=512)
parser.add_argument("--num_conv", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--gat", action="store_true")
parser.add_argument("--gat_k", type=int, default=1)
parser.add_argument("--balance", action="store_true")
parser.add_argument("--use_cluster_feat", action="store_true")
parser.add_argument("--use_focal_loss", action="store_true")
parser.add_argument("--use_gt", action="store_true")

# Subgraph
parser.add_argument("--batch_size", type=int, default=4096)

args = parser.parse_args()
print(args)
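
# Example invocation (file names below are hypothetical; point --data_path at
# your pickled (features, labels) file and --model_filename at a trained
# checkpoint):
#   python test_subg.py --data_path data/test.pkl --model_filename lander.pth \
#       --knn_k 10 --levels 2 --faiss_gpu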

###########################
# Environment Configuration
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

##################
# Data Preparation
with open(args.data_path, "rb") as f:
    features, labels = pickle.load(f)
global_features = features.copy()
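
# The pickle is assumed to hold a (features, labels) pair: `features` as an
# (N, D) float array of embeddings and `labels` as an (N,) integer array of
# ground-truth cluster ids, matching what LanderDataset expects. The level-0
# features are copied above so coarser levels can still look up the original
# embeddings.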
dataset = LanderDataset(
    features=features,
    labels=labels,
    k=args.knn_k,
    levels=1,
    faiss_gpu=args.faiss_gpu,
)
g = dataset.gs[0]
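
# Placeholders for the model outputs: per-node predicted density and per-edge
# (no-link, link) probabilities, filled in by the inference pass below.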
g.ndata["pred_den"] = torch.zeros(g.number_of_nodes())
g.edata["prob_conn"] = torch.zeros(g.number_of_edges(), 2)
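
# Cross-level bookkeeping: `ids` maps the current level's nodes back to
# level-0 node ids, `global_edges` accumulates every decoded merge edge, and
# `global_peaks` records the density peaks kept at each level.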
global_labels = labels.copy()
ids = np.arange(g.number_of_nodes())
global_edges = ([], [])
global_peaks = np.array([], dtype=np.int64)
global_edges_len = len(global_edges[0])
global_num_nodes = g.number_of_nodes()

fanouts = [args.knn_k - 1 for _ in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# A fanout of knn_k - 1 covers each node's full neighbor set in the k-NN
# graph, so the number of sampled edges per node is fixed.
test_loader = dgl.dataloading.DataLoader(
    g,
    torch.arange(g.number_of_nodes()),
    sampler,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)
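# With shuffle=False and drop_last=False, one pass over this loader scores
# every node and edge of the current level's graph exactly once.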

##################
# Model Definition
if not args.use_gt:
    feature_dim = g.ndata["features"].shape[1]
    model = LANDER(
        feature_dim=feature_dim,
        nhid=args.hidden,
        num_conv=args.num_conv,
        dropout=args.dropout,
        use_GAT=args.gat,
        K=args.gat_k,
        balance=args.balance,
        use_cluster_feat=args.use_cluster_feat,
        use_focal_loss=args.use_focal_loss,
    )
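    # The architecture flags above must match the ones used at training time;
    # otherwise load_state_dict below fails on mismatched parameters.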
    model.load_state_dict(
        torch.load(args.model_filename, map_location="cpu")
    )
    model = model.to(device)
    model.eval()

# The number of edges added per level indicates convergence; it drives the
# optional early stopping below.
num_edges_add_last_level = np.inf
##################################
# Predict connectivity and density
for level in range(args.levels):
    if not args.use_gt:
        total_batches = len(test_loader)
        for batch, minibatch in enumerate(test_loader):
            input_nodes, sub_g, bipartites = minibatch
            sub_g = sub_g.to(device)
            bipartites = [b.to(device) for b in bipartites]
            with torch.no_grad():
                output_bipartite = model(bipartites)
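            # Write the minibatch predictions back into the full graph via
            # the global node and edge ids carried by the output block.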
            global_nid = output_bipartite.dstdata[dgl.NID]
            global_eid = output_bipartite.edata["global_eid"]
            g.ndata["pred_den"][global_nid] = output_bipartite.dstdata[
                "pred_den"
            ].to("cpu")
            g.edata["prob_conn"][global_eid] = output_bipartite.edata[
                "prob_conn"
            ].to("cpu")
            torch.cuda.empty_cache()
            if (batch + 1) % 10 == 0:
                print("Batch %d / %d for inference" % (batch, total_batches))

    (
        new_pred_labels,
        peaks,
        global_edges,
        global_pred_labels,
        global_peaks,
    ) = decode(
        g,
        args.tau,
        args.threshold,
        args.use_gt,
        ids,
        global_edges,
        global_num_nodes,
        global_peaks,
    )
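    # decode() grows clusters from the predicted densities and connection
    # probabilities; nodes left unmerged are density peaks, which seed the
    # next, coarser level.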
    ids = ids[peaks]
    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
    if stop_iterating(
        level,
        args.levels,
        args.early_stop,
        num_edges_add_this_level,
        num_edges_add_last_level,
        args.knn_k,
    ):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level

    # build new dataset
    features, labels, cluster_features = build_next_level(
        features,
        labels,
        peaks,
        global_features,
        global_pred_labels,
        global_peaks,
    )
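    # Each surviving peak becomes a node of the next level, carrying its
    # original embedding from global_features plus an aggregated feature for
    # the cluster it represents.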
    # After the first level, the number of nodes drops sharply, so CPU faiss
    # is faster than GPU faiss.
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=args.knn_k,
        levels=1,
        faiss_gpu=False,
        cluster_features=cluster_features,
    )
    g = dataset.gs[0]
    g.ndata["pred_den"] = torch.zeros(g.number_of_nodes())
    g.edata["prob_conn"] = torch.zeros(g.number_of_edges(), 2)
    test_loader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.number_of_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )
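
# Report the requested metrics (pairwise F-score, BCubed F-score, and NMI by
# default) for the final predicted clustering against the ground-truth labels.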
evaluation(global_pred_labels, global_labels, args.metrics)