import argparse
import multiprocessing
import time
from collections import defaultdict
from functools import partial, reduce, wraps

import networkx as nx
import numpy as np
import torch
from gensim.models.keyedvectors import Vocab
from six import iteritems
from sklearn.metrics import auc, f1_score, precision_recall_curve, roc_auc_score


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--input", type=str, default="data/amazon", help="Input dataset path"
    )

    parser.add_argument(
        "--features", type=str, default=None, help="Input node features"
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=100,
        help="Number of epoch. Default is 100.",
    )

    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Number of batch_size. Default is 64.",
    )

    parser.add_argument(
        "--eval-type",
        type=str,
        default="all",
        help="The edge type(s) for evaluation.",
    )

    parser.add_argument(
        "--schema",
        type=str,
        default=None,
        help="The metapath schema (e.g., U-I-U,I-U-I).",
    )

    parser.add_argument(
        "--dimensions",
        type=int,
        default=200,
        help="Number of dimensions. Default is 200.",
    )

    parser.add_argument(
        "--edge-dim",
        type=int,
        default=10,
        help="Number of edge embedding dimensions. Default is 10.",
    )

    parser.add_argument(
        "--att-dim",
        type=int,
        default=20,
        help="Number of attention dimensions. Default is 20.",
    )

    parser.add_argument(
        "--walk-length",
        type=int,
        default=10,
        help="Length of walk per source. Default is 10.",
    )

    parser.add_argument(
        "--num-walks",
        type=int,
        default=20,
        help="Number of walks per source. Default is 20.",
    )

    parser.add_argument(
        "--window-size",
        type=int,
        default=5,
        help="Context size for optimization. Default is 5.",
    )

    parser.add_argument(
        "--negative-samples",
        type=int,
        default=5,
        help="Negative samples for optimization. Default is 5.",
    )

    parser.add_argument(
        "--neighbor-samples",
        type=int,
        default=10,
        help="Neighbor samples for aggregation. Default is 10.",
    )

    parser.add_argument(
        "--patience",
        type=int,
        default=5,
        help="Early stopping patience. Default is 5.",
    )

    parser.add_argument(
        "--gpu",
        type=str,
        default=None,
        help="Comma separated list of GPU device IDs.",
    )

    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of workers.",
    )

    return parser.parse_args()
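

# Example invocation (illustrative; "main.py" stands in for whatever script
# calls parse_args()):
#   python main.py --input data/amazon --epoch 100 --batch-size 64 --gpu 0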


# Each line of the training file is: <edge_type> <node> <node>
def load_training_data(f_name):
    print("We are loading data from:", f_name)
    edge_data_by_type = dict()
    all_nodes = list()
    with open(f_name, "r") as f:
        for line in f:
            words = line.strip().split(" ")
            if words[0] not in edge_data_by_type:
                edge_data_by_type[words[0]] = list()
            x, y = words[1], words[2]
            edge_data_by_type[words[0]].append((x, y))
            all_nodes.append(x)
            all_nodes.append(y)
    all_nodes = list(set(all_nodes))
    print("Total training nodes: " + str(len(all_nodes)))
    return edge_data_by_type
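
# Example (illustrative): a training line "1 12 34" yields
#   edge_data_by_type == {"1": [("12", "34")]}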


# Each line of the testing file is: <edge_type> <node> <node> <1_or_0>
def load_testing_data(f_name):
    print("We are loading data from:", f_name)
    true_edge_data_by_type = dict()
    false_edge_data_by_type = dict()
    with open(f_name, "r") as f:
        for line in f:
            words = line.strip().split(" ")
            x, y = words[1], words[2]
            if int(words[3]) == 1:
                if words[0] not in true_edge_data_by_type:
                    true_edge_data_by_type[words[0]] = list()
                true_edge_data_by_type[words[0]].append((x, y))
            else:
                if words[0] not in false_edge_data_by_type:
                    false_edge_data_by_type[words[0]] = list()
                false_edge_data_by_type[words[0]].append((x, y))
    return true_edge_data_by_type, false_edge_data_by_type
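
# Example (illustrative): a testing line "1 12 34 1" appends ("12", "34") to
# true_edge_data_by_type["1"]; a trailing 0 would send it to
# false_edge_data_by_type["1"] instead.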


def load_node_type(f_name):
    print("We are loading node type from:", f_name)
    node_type = {}
    with open(f_name, "r") as f:
        for line in f:
            items = line.strip().split()
            node_type[items[0]] = items[1]
    return node_type
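
# Example (illustrative): a line "12 U" gives node_type["12"] == "U".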


# Generate (center, context, layer_id) skip-gram pairs from random walks.
def generate_pairs_parallel(walks, skip_window=None, layer_id=None):
    pairs = []
    for walk in walks:
        walk = walk.tolist()
        for i in range(len(walk)):
            for j in range(1, skip_window + 1):
                if i - j >= 0:
                    pairs.append((walk[i], walk[i - j], layer_id))
                if i + j < len(walk):
                    pairs.append((walk[i], walk[i + j], layer_id))
    return pairs
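
# Worked example (illustrative):
#   generate_pairs_parallel([np.array([0, 1, 2])], skip_window=1, layer_id=0)
#   -> [(0, 1, 0), (1, 0, 0), (1, 2, 0), (2, 1, 0)]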


def generate_pairs(all_walks, window_size, num_workers):
    # For each node in a walk, pair it with every neighbor that lies within
    # half the window size on either side.
    start_time = time.time()
    print("Generating pairs with {} workers.".format(num_workers))

    # Start all worker processes
    pool = multiprocessing.Pool(processes=num_workers)
    pairs = []
    skip_window = window_size // 2
    for layer_id, walks in enumerate(all_walks):
        # Split the walks into one chunk per worker; the last chunk absorbs
        # the remainder so no walk is dropped.
        block_num = len(walks) // num_workers
        if block_num > 0:
            walks_list = [
                walks[i * block_num : (i + 1) * block_num]
                for i in range(num_workers - 1)
            ] + [walks[(num_workers - 1) * block_num :]]
        else:
            walks_list = [walks]
        tmp_result = pool.map(
            partial(
                generate_pairs_parallel,
                skip_window=skip_window,
                layer_id=layer_id,
            ),
            walks_list,
        )
        pairs += reduce(lambda x, y: x + y, tmp_result)

    pool.close()
    end_time = time.time()
    print("Generate pairs end, use {}s.".format(end_time - start_time))
    return np.array([list(pair) for pair in set(pairs)])
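
# Usage sketch (illustrative): all_walks holds one batch of walks per edge
# type, e.g. all_walks = [np.array([[0, 1, 2], [2, 1, 0]])]; the result is an
# (N, 3) array of unique (center, context, layer_id) rows. Because a
# multiprocessing.Pool is created here, call this from a __main__-guarded
# script on platforms that spawn worker processes.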


def generate_vocab(network_data):
    nodes, index2word = [], []
    for edge_type in network_data:
        node1, node2 = zip(*network_data[edge_type])
        index2word = index2word + list(node1) + list(node2)

    index2word = list(set(index2word))
    vocab = {word: i for i, word in enumerate(index2word)}

    for edge_type in network_data:
        node1, node2 = zip(*network_data[edge_type])
        tmp_nodes = list(set(list(node1) + list(node2)))
        tmp_nodes = [vocab[word] for word in tmp_nodes]
        nodes.append(tmp_nodes)

    return index2word, vocab, nodes
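
# Example (illustrative): for network_data == {"1": [("a", "b")]}, index2word
# is some ordering of ["a", "b"], vocab maps each word to its index, and
# nodes == [[vocab["a"], vocab["b"]]] (inner order is arbitrary because both
# lists pass through set()).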


# Cosine similarity between the embeddings of an edge's two endpoints.
def get_score(local_model, edge):
    node1, node2 = str(edge[0]), str(edge[1])
    try:
        vector1 = local_model[node1]
        vector2 = local_model[node2]
        return np.dot(vector1, vector2) / (
            np.linalg.norm(vector1) * np.linalg.norm(vector2)
        )
    except KeyError:
        # One of the nodes has no embedding; the caller filters out None.
        return None
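
# Example (illustrative): identical directions score 1.0, orthogonal ones 0.0:
#   get_score({"0": np.array([1.0, 0.0]), "1": np.array([1.0, 0.0])}, ("0", "1"))
#   -> 1.0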


def evaluate(model, true_edges, false_edges, num_workers):
    true_list = list()
    prediction_list = list()
    true_num = 0

    # Score all edges in parallel with a pool of worker processes
    pool = multiprocessing.Pool(processes=num_workers)
    tmp_true_score_list = pool.map(partial(get_score, model), true_edges)
    tmp_false_score_list = pool.map(partial(get_score, model), false_edges)
    pool.close()

    prediction_list += [
        tmp_score for tmp_score in tmp_true_score_list if tmp_score is not None
    ]
    true_num = len(prediction_list)
    true_list += [1] * true_num

    prediction_list += [
        tmp_score for tmp_score in tmp_false_score_list if tmp_score is not None
    ]
    true_list += [0] * (len(prediction_list) - true_num)

    # Use the true_num-th largest score as the decision threshold, so the
    # number of predicted positives equals the number of true edges.
    sorted_pred = prediction_list[:]
    sorted_pred.sort()
    threshold = sorted_pred[-true_num]

    y_true = np.array(true_list)
    y_scores = np.array(prediction_list)
    y_pred = (y_scores >= threshold).astype(np.int32)
    ps, rs, _ = precision_recall_curve(y_true, y_scores)
    # (ROC-AUC, F1 at the chosen threshold, PR-AUC)
    return (
        roc_auc_score(y_true, y_scores),
        f1_score(y_true, y_pred),
        auc(rs, ps),
    )
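

# Usage sketch (illustrative): "model" maps node id (str) to its embedding
# vector, e.g. a dict or gensim KeyedVectors:
#   roc_auc, f1, pr_auc = evaluate(model, true_edges, false_edges, 4)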