sampler.py 3.32 KB
Newer Older
ziqiaomeng's avatar
ziqiaomeng committed
1
2
3
import numpy as np
import random
import time
4
5
6
7
import tqdm
import dgl
import sys
import os
ziqiaomeng's avatar
ziqiaomeng committed
8
9
10

# Number of random walks started from each conference node.
num_walks_per_node = 1000
# Number of times the 4-hop conf -> paper -> author -> paper -> conf
# metapath is repeated within a single walk.
walk_length = 100
# Directory that holds the input text files; output_path.txt is written
# there as well. Read at import time, so the directory must be supplied as
# the first command-line argument.
path = sys.argv[1]
ziqiaomeng's avatar
ziqiaomeng committed
12

13
14
15
16
17
18
19
20
21
22
def _read_id_name_pairs(file_path, name_of):
    """Read an "<id> <name ...>" file into parallel lists of ids and names.

    Args:
        file_path: Path of the ISO-8859-1 encoded text file to read.
        name_of: Callable that receives the whitespace-split fields of one
            line and returns the name to store for that line.

    Returns:
        A ``(ids, names)`` tuple of lists in file order.

    Raises:
        ValueError: If the first field of a line is not an integer.
        IndexError: If ``name_of`` indexes a field the line does not have.
    """
    ids = []
    names = []
    with open(file_path, encoding="ISO-8859-1") as handle:
        for line in handle:
            fields = line.split()
            if not fields:
                # Blank (e.g. trailing) lines carry no record.
                continue
            ids.append(int(fields[0]))
            names.append(name_of(fields))
    return ids, names


def _read_edges(file_path, src_index, dst_index):
    """Read a tab-separated "<src id>\\t<dst id>" file as index lists.

    Args:
        file_path: Path of the edge file to read.
        src_index: Mapping from raw id in the first column to node index.
        dst_index: Mapping from raw id in the second column to node index.

    Returns:
        A ``(src, dst)`` tuple of equally long lists of node indices.

    Raises:
        ValueError: If a column is not an integer.
        KeyError: If an id does not appear in the corresponding mapping.
    """
    src = []
    dst = []
    with open(file_path, "r") as handle:
        for line in handle:
            if not line.strip():
                continue
            columns = line.split('\t')
            src_id = int(columns[0])
            dst_id = int(columns[1].strip('\n'))
            src.append(src_index[src_id])
            dst.append(dst_index[dst_id])
    return src, dst


def construct_graph():
    """Build the paper/author/conference heterograph from the files in ``path``.

    Reads ``id_author.txt``, ``id_conf.txt`` and ``paper.txt`` for the node
    ids and names, then ``paper_author.txt`` and ``paper_conf.txt`` for the
    edges. Raw ids are remapped to their position in the node files, and
    every relation is added in both directions ('pa'/'ap' and 'pc'/'cp').

    Returns:
        A tuple ``(hg, author_names, conf_names, paper_names)`` where ``hg``
        is the DGL heterograph and each name list is in node-file order.
        Author and conference names are the second whitespace-separated
        field of their line; paper names are ``'p'`` followed by all
        remaining fields joined without separators.
    """
    author_ids, author_names = _read_id_name_pairs(
        os.path.join(path, "id_author.txt"), lambda fields: fields[1])
    conf_ids, conf_names = _read_id_name_pairs(
        os.path.join(path, "id_conf.txt"), lambda fields: fields[1])
    paper_ids, paper_names = _read_id_name_pairs(
        os.path.join(path, "paper.txt"),
        lambda fields: 'p' + ''.join(fields[1:]))

    # Raw id -> position in the node file (the last occurrence wins if an
    # id is repeated).
    author_ids_invmap = {x: i for i, x in enumerate(author_ids)}
    conf_ids_invmap = {x: i for i, x in enumerate(conf_ids)}
    paper_ids_invmap = {x: i for i, x in enumerate(paper_ids)}

    paper_author_src, paper_author_dst = _read_edges(
        os.path.join(path, "paper_author.txt"),
        paper_ids_invmap, author_ids_invmap)
    paper_conf_src, paper_conf_dst = _read_edges(
        os.path.join(path, "paper_conf.txt"),
        paper_ids_invmap, conf_ids_invmap)

    # NOTE(review): dgl.bipartite / dgl.hetero_from_relations look like the
    # DGL 0.4-era API -- confirm against the installed DGL version.
    pa = dgl.bipartite((paper_author_src, paper_author_dst), 'paper', 'pa', 'author')
    ap = dgl.bipartite((paper_author_dst, paper_author_src), 'author', 'ap', 'paper')
    pc = dgl.bipartite((paper_conf_src, paper_conf_dst), 'paper', 'pc', 'conf')
    cp = dgl.bipartite((paper_conf_dst, paper_conf_src), 'conf', 'cp', 'paper')
    hg = dgl.hetero_from_relations([pa, ap, pc, cp])
    return hg, author_names, conf_names, paper_names
ziqiaomeng's avatar
ziqiaomeng committed
83
84
85

# "conference - paper - author - paper - conference" metapath sampling
def generate_metapath():
    """Sample metapath random walks and write them to ``output_path.txt``.

    For every conference node, ``num_walks_per_node`` walks are sampled that
    follow the 'cp' -> 'pa' -> 'ap' -> 'pc' metapath repeated ``walk_length``
    times. Each walk is written as one line to ``output_path.txt`` inside
    ``path``: the space-separated names of the conference and author nodes
    visited, with paper nodes left out.
    """
    # Build the graph before opening the output so a failure while loading
    # the inputs does not truncate an existing output file.
    hg, author_names, conf_names, paper_names = construct_graph()
    metapath = ['cp', 'pa', 'ap', 'pc'] * walk_length

    with open(os.path.join(path, "output_path.txt"), "w") as output_file:
        for conf_idx in tqdm.trange(hg.number_of_nodes('conf')):
            traces = dgl.contrib.sampling.metapath_random_walk(
                hg, metapath, [conf_idx], num_walks_per_node)
            # A single seed was passed, so only the first entry is used.
            for trace in traces[0]:
                # Prepend the seed so that position 0 is the start conference.
                tr = np.insert(trace.numpy(), 0, conf_idx)
                # Even positions alternate conf (i % 4 == 0) and author
                # (i % 4 == 2); odd positions are papers and are skipped.
                outline = ' '.join(
                    (conf_names if i % 4 == 0 else author_names)[tr[i]]
                    for i in range(0, len(tr), 2))
                print(outline, file=output_file)


# Script entry point: sample the metapath walks when run directly.
if __name__ == "__main__":
    generate_metapath()