# reading_data.py
1
import os
2
3
4
5
6
import numpy as np
import scipy.sparse as sp
import pickle
import torch
from torch.utils.data import DataLoader
7
from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_archive
8
9
10
import random
import time
import dgl
11

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from utils import shuffle_walks

def ReadTxtNet(file_path="", undirected=True):
    """ Read the txt network file.

    Each non-blank line holds ``src dst`` or ``src dst weight`` as
    whitespace-separated integers; when the weight is missing it
    defaults to 1. Blank lines are skipped. Only the first occurrence of
    an edge is kept, so a repeated edge does not overwrite its weight.

    Parameters
    ----------
    file_path str : path of network file, or one of the dataset names
        'youtube' / 'blog', which is downloaded and extracted first
    undirected bool : whether the edges are undirected (if so, the
        reverse of every edge is recorded as well)

    Return
    ------
    net dict : a dict recording the connections in the graph,
        net[src_idx][dst_idx] = weight
    node2id dict : a dict mapping the nodes to their embedding indices
    id2node dict : a dict mapping nodes embedding indices to the nodes
    sm scipy.sparse.coo_matrix : float32 weighted adjacency matrix
        over the embedding indices

    Raises
    ------
    AssertionError : if a line does not have 2 or 3 fields, if no edge
        was read, or if the source indices in ``net`` are not 0..len(net)-1
    ValueError : if a field is not an integer
    """
    if file_path == 'youtube' or file_path == 'blog':
        name = file_path
        download_dir = get_download_dir()
        zip_file_path = '{}/{}.zip'.format(download_dir, name)
        download(_get_dgl_url(os.path.join('dataset/DeepWalk/', '{}.zip'.format(file_path))), path=zip_file_path)
        extract_archive(zip_file_path,
                        '{}/{}'.format(download_dir, name))
        file_path = "{}/{}/{}-net.txt".format(download_dir, name, name)

    node2id = {}
    id2node = {}
    cid = 0

    src = []
    dst = []
    weight = []
    net = {}
    with open(file_path, "r") as f:
        # iterate lazily instead of loading the whole file with readlines()
        for line in f:
            # split() copes with tabs / repeated spaces and yields [] for
            # blank lines, which would otherwise crash int('')
            fields = line.split()
            if not fields:
                continue
            tup = list(map(int, fields))
            assert len(tup) in [2, 3], "The format of network file is unrecognizable."
            if len(tup) == 3:
                n1, n2, w = tup
            else:
                n1, n2 = tup
                w = 1

            # assign embedding indices in order of first appearance
            for node in (n1, n2):
                if node not in node2id:
                    node2id[node] = cid
                    id2node[cid] = node
                    cid += 1

            n1 = node2id[n1]
            n2 = node2id[n2]

            directions = [(n1, n2)]
            if undirected:
                directions.append((n2, n1))
            for a, b in directions:
                neighbors = net.setdefault(a, {})
                if b not in neighbors:
                    neighbors[b] = w
                    src.append(a)
                    dst.append(b)
                    weight.append(w)

    print("node num: %d" % len(net))
    print("edge num: %d" % len(src))
    # `net` is checked first so an empty file fails with this message
    # rather than with a ValueError from max() on an empty sequence
    assert net and max(net.keys()) == len(net) - 1, "error reading net, quit"

    # pin the shape to the number of indexed nodes so the matrix is square
    # even when some node never appears as a source (directed input)
    sm = sp.coo_matrix(
        (np.array(weight), (src, dst)),
        shape=(cid, cid),
        dtype=np.float32)

    return net, node2id, id2node, sm

def net2graph(net_sm):
    """ Build a DGL graph from the network and report the build time

    Parameter
    ---------
    net_sm : the network, passed unchanged to the dgl.DGLGraph constructor

    Return
    ------
    G DGLGraph : graph by DGL
    """
    began = time.time()
    graph = dgl.DGLGraph(net_sm)
    elapsed = time.time() - began
    print("Building DGLGraph in %.2fs" % elapsed)
    return graph

113
114
115
116
117
118
119
120
121
122
123
124
def make_undirected(G):
    """ Add the reverse of every existing edge to G, in place

    The graph is first switched out of read-only mode so edges can be
    added. Returns the same graph object.
    """
    G.readonly(False)
    heads, tails = G.edges()[0], G.edges()[1]
    G.add_edges(tails, heads)
    return G

def find_connected_nodes(G):
    """ Collect the ids of all nodes in G that have at least one out-edge

    Return
    ------
    list : node ids (converted with .item()) in the order of G.nodes()
    """
    return [node.item() for node in G.nodes() if G.out_degree(node) > 0]

125
126
127
128
class DeepwalkDataset:
    def __init__(self,
            net_file,
            map_file,
            walk_length,
            window_size,
            num_walks,
            batch_size,
            negative=5,
            gpus=[0],
            fast_neg=True,
            ogbl_name="",
            load_from_ogbl=False,
            ):
        """ This class has the following functions:
        1. Transform the txt network file into DGL graph;
        2. Generate random walk sequences for the trainer;
        3. Provide the negative table if the user hopes to sample negative
        nodes according to nodes' degrees;

        Parameter
        ---------
        net_file str : path of the txt network file
            (ignored when load_from_ogbl is True)
        map_file str : path the node-to-index mapping is pickled to
            (ignored when load_from_ogbl is True)
        walk_length int : number of nodes in a sequence
        window_size int : context window size
        num_walks int : number of walks for each node
        batch_size int : number of node sequences in each batch
        negative int : negative samples for each positve node pair
        gpus list : only its length is used here, as the number of
            partitions the walk seeds are split into; never mutated
        fast_neg bool : whether do negative sampling inside a batch;
            when False a degree-based negative table is built
        ogbl_name str : dataset name passed to load_from_ogbl_with_name
        load_from_ogbl bool : load the graph through load_dataset instead
            of reading net_file
        """
        self.walk_length = walk_length
        self.window_size = window_size
        self.num_walks = num_walks
        self.batch_size = batch_size
        self.negative = negative
        self.num_procs = len(gpus)
        self.fast_neg = fast_neg

        if load_from_ogbl:
            assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training (CUDA error)."
            from load_dataset import load_from_ogbl_with_name
            self.G = load_from_ogbl_with_name(ogbl_name)
            self.G = make_undirected(self.G)
        else:
            self.net, self.node2id, self.id2node, self.sm = ReadTxtNet(net_file)
            self.save_mapping(map_file)
            self.G = net2graph(self.sm)

        self.num_nodes = self.G.number_of_nodes()

        # random walk seeds
        start = time.time()
        # only nodes with at least one out-edge are used as walk seeds
        self.valid_seeds = find_connected_nodes(self.G)
        if len(self.valid_seeds) != self.num_nodes:
            print("WARNING: The node ids are not serial. Some nodes are invalid.")

        # every valid node is repeated num_walks times, shuffled, then cut
        # into chunks of ceil(total / num_procs) seeds
        seeds = torch.cat([torch.LongTensor(self.valid_seeds)] * num_walks)
        self.seeds = torch.split(shuffle_walks(seeds),
            int(np.ceil(len(self.valid_seeds) * self.num_walks / self.num_procs)),
            0)
        end = time.time()
        t = end - start
        print("%d seeds in %.2fs" % (len(seeds), t))

        # negative table for true negative sampling
        if not fast_neg:
            node_degree = np.array(
                [self.G.out_degree(x) for x in self.valid_seeds])
            # weight each node by degree^0.75, normalised to sum to 1
            node_degree = np.power(node_degree, 0.75)
            node_degree /= np.sum(node_degree)
            # scale to integer repeat counts (about 1e8 table entries in
            # total). np.int64 replaces the np.int / np.long aliases, which
            # were removed from NumPy.
            node_degree = np.array(node_degree * 1e8, dtype=np.int64)
            # each node id appears node_degree[idx] times in the table
            self.neg_table = np.repeat(
                np.array(self.valid_seeds, dtype=np.int64), node_degree)
            self.neg_table_size = len(self.neg_table)
            del node_degree

    def create_sampler(self, i):
        """ create random walk sampler over the i-th chunk of seeds """
        return DeepwalkSampler(self.G, self.seeds[i], self.walk_length)

    def save_mapping(self, map_file):
        """ save the mapping dict that maps node IDs to embedding indices

        The dict self.node2id is pickled to the file at map_file.
        """
        with open(map_file, "wb") as f:
            pickle.dump(self.node2id, f)

class DeepwalkSampler(object):
    def __init__(self, G, seeds, walk_length):
        """ Sampler that draws random walks on a graph

        Parameter
        ---------
        G dgl.Graph : graph the walks are drawn on
        seeds torch.LongTensor : starting nodes kept on the sampler
        walk_length int : number of nodes in each walk
        """
        self.walk_length = walk_length
        self.seeds = seeds
        self.G = G

    def sample(self, seeds):
        """ Draw one random walk per node in `seeds`

        The walks start from the `seeds` argument, not from self.seeds.
        Each walk takes walk_length - 1 hops.
        """
        num_traces = 1
        num_hops = self.walk_length - 1
        return dgl.contrib.sampling.random_walk(
            self.G, seeds, num_traces, num_hops)