"sgl-kernel/vscode:/vscode.git/clone" did not exist on "bdda6c42eb856f86dc781970f79e550c9a162ca0"
sampler.py 13.6 KB
Newer Older
1
2
3
4
5
6
import math
import os
import sys
import pickle
import time

import numpy as np
import scipy as sp
import scipy.sparse  # ensure sp.sparse is available; scipy does not auto-import submodules

import dgl
import dgl.backend as F

# This partitions a list of edges based on relations to make sure
# each partition has roughly the same number of edges and relations.
def RelationPartition(edges, n):
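    """Partition edges by relation so each of the n parts gets roughly the
    same number of edges while every relation stays within a single part.

    Relations are sorted by frequency (largest first) and greedily assigned
    to the part with the fewest edges so far. For example, with relation
    counts {r0: 50, r1: 30, r2: 20} and n=2, r0 goes to part 0, then r1 and
    r2 go to part 1, giving 50 edges in each part.

    Returns (parts, rel_parts): per-part arrays of edge indices and of
    relation IDs.
    """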
    heads, rels, tails = edges
    print('relation partition {} edges into {} parts'.format(len(heads), n))
    uniq, cnts = np.unique(rels, return_counts=True)
    idx = np.flip(np.argsort(cnts))
    cnts = cnts[idx]
    uniq = uniq[idx]
    assert cnts[0] >= cnts[-1]
    edge_cnts = np.zeros(shape=(n,), dtype=np.int64)
    rel_cnts = np.zeros(shape=(n,), dtype=np.int64)
    rel_dict = {}
    rel_parts = [[] for _ in range(n)]
    for i in range(len(cnts)):
        cnt = cnts[i]
        r = uniq[i]
        idx = np.argmin(edge_cnts)
        rel_dict[r] = idx
        rel_parts[idx].append(r)
        edge_cnts[idx] += cnt
        rel_cnts[idx] += 1
    for i, edge_cnt in enumerate(edge_cnts):
        print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))

    parts = [[] for _ in range(n)]
    for i in range(n):
        rel_parts[i] = np.array(rel_parts[i])
    # Store each edge's index in the partition of its relation.
    for i, r in enumerate(rels):
        part_idx = rel_dict[r]
        parts[part_idx].append(i)
    for i, part in enumerate(parts):
        parts[i] = np.array(part, dtype=np.int64)
    return parts, rel_parts

def RandomPartition(edges, n):
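    """Shuffle the edge indices and split them into n roughly equal parts.

    For example, 10 edges in 3 parts gives part_size = ceil(10 / 3) = 4, so
    the parts receive 4, 4 and 2 edges.
    """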
    heads, rels, tails = edges
    print('random partition {} edges into {} parts'.format(len(heads), n))
    idx = np.random.permutation(len(heads))
    part_size = int(math.ceil(len(idx) / n))
    parts = []
    for i in range(n):
        start = part_size * i
        end = min(part_size * (i + 1), len(idx))
        parts.append(idx[start:end])
        print('part {} has {} edges'.format(i, len(parts[-1])))
    return parts

def ConstructGraph(edges, n_entities, args):
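    """Build a read-only DGLGraph over the triples, or load a pickled one.

    Node IDs are stored in g.ndata['id'] and relation types in g.edata['id'].
    The fields expected on args (inferred from the usage below): pickle_graph
    (bool), plus data_path and dataset (path components for the pickle cache).
    """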
    pickle_name = 'graph_train.pickle'
    if args.pickle_graph and os.path.exists(os.path.join(args.data_path, args.dataset, pickle_name)):
        with open(os.path.join(args.data_path, args.dataset, pickle_name), 'rb') as graph_file:
            g = pickle.load(graph_file)
            print('Load pickled graph.')
    else:
        src, etype_id, dst = edges
        coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), shape=[n_entities, n_entities])
        g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
        g.ndata['id'] = F.arange(0, g.number_of_nodes())
        g.edata['id'] = F.tensor(etype_id, F.int64)
        if args.pickle_graph:
            with open(os.path.join(args.data_path, args.dataset, pickle_name), 'wb') as graph_file:
                pickle.dump(g, graph_file)
    return g

class TrainDataset(object):
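    """Holds the training graph and its edge partitions.

    With ranks > 1 the training edges are split across workers, either by
    relation (args.rel_part) or at random; with a single rank all edges go
    into one part.
    """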
    def __init__(self, dataset, args, weighting=False, ranks=64):
        triples = dataset.train
        self.g = ConstructGraph(triples, dataset.n_entities, args)
        num_train = len(triples[0])
        print('|Train|:', num_train)
        if ranks > 1 and args.rel_part:
            self.edge_parts, self.rel_parts = RelationPartition(triples, ranks)
        elif ranks > 1:
            self.edge_parts = RandomPartition(triples, ranks)
        else:
            self.edge_parts = [np.arange(num_train)]
        if weighting:
            # TODO: weight to be added
            count = self.count_freq(triples)
            subsampling_weight = np.vectorize(
                lambda h, r, t: np.sqrt(1 / (count[(h, r)] + count[(t, -r - 1)]))
            )
            heads, rels, tails = triples
            weight = subsampling_weight(heads, rels, tails)
            self.g.edata['weight'] = F.zerocopy_from_numpy(weight)

    def count_freq(self, triples, start=4):
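        """Count (head, rel) and (tail, -rel-1) occurrences across the
        triples, starting each count at `start` to smooth the subsampling
        weights."""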
        count = {}
        # triples is a tuple of (heads, rels, tails) arrays; zip to iterate per triple.
        for head, rel, tail in zip(*triples):
            if (head, rel) not in count:
                count[(head, rel)] = start
            else:
                count[(head, rel)] += 1

            if (tail, -rel - 1) not in count:
                count[(tail, -rel - 1)] = start
            else:
                count[(tail, -rel - 1)] += 1
        return count

    def create_sampler(self, batch_size, neg_sample_size=2, neg_chunk_size=None, mode='head', num_workers=5,
                       shuffle=True, exclude_positive=False, rank=0):
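        """Create an EdgeSampler over this rank's edge partition, yielding
        (positive, negative) subgraph pairs for training."""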
        EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
        return EdgeSampler(self.g,
                           seed_edges=F.tensor(self.edge_parts[rank]),
                           batch_size=batch_size,
                           neg_sample_size=neg_sample_size,
                           chunk_size=neg_chunk_size,
                           negative_mode=mode,
                           num_workers=num_workers,
                           shuffle=shuffle,
                           exclude_positive=exclude_positive,
                           return_false_neg=False)


class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
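    """Wrapper over a negative-edge subgraph that records its chunk layout:
    the positive edges form num_chunks chunks of chunk_size edges, and each
    positive edge has neg_sample_size negative counterparts, so the negative
    graph holds num_chunks * chunk_size * neg_sample_size edges."""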
    def __init__(self, subg, num_chunks, chunk_size,
                 neg_sample_size, neg_head):
        super(ChunkNegEdgeSubgraph, self).__init__(subg._parent, subg.sgi)
        self.subg = subg
        self.num_chunks = num_chunks
        self.chunk_size = chunk_size
        self.neg_sample_size = neg_sample_size
        self.neg_head = neg_head

    @property
    def head_nid(self):
        return self.subg.head_nid

    @property
    def tail_nid(self):
        return self.subg.tail_nid


# KG models need to know the number of chunks, the chunk size and the negative
# sample size of a negative subgraph to perform the computation more efficiently.
# This function tries to infer all of this information from the negative subgraph
# and creates a wrapper class that carries it.
def create_neg_subgraph(pos_g, neg_g, chunk_size, is_chunked, neg_head, num_nodes):
    assert neg_g.number_of_edges() % pos_g.number_of_edges() == 0
    neg_sample_size = int(neg_g.number_of_edges() / pos_g.number_of_edges())
    # We use all nodes to create negative edges. Regardless of the sampling algorithm,
    # we can always view the subgraph with one chunk.
    if (neg_head and len(neg_g.head_nid) == num_nodes) \
       or (not neg_head and len(neg_g.tail_nid) == num_nodes):
        num_chunks = 1
        chunk_size = pos_g.number_of_edges()
    elif is_chunked:
        if pos_g.number_of_edges() < chunk_size:
            num_chunks = 1
            chunk_size = pos_g.number_of_edges()
        else:
            # This is probably the last batch. Let's ignore it.
            if pos_g.number_of_edges() % chunk_size > 0:
                return None
            num_chunks = int(pos_g.number_of_edges() / chunk_size)
        assert num_chunks * chunk_size == pos_g.number_of_edges()
        assert num_chunks * neg_sample_size * chunk_size == neg_g.number_of_edges()
    else:
        num_chunks = pos_g.number_of_edges()
        chunk_size = 1
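    # For example, with 1000 positive edges, chunk_size=100 and neg_sample_size=16,
    # the chunked case gives num_chunks=10 and the negative graph holds
    # 10 * 100 * 16 = 16000 edges, matching the assertions above.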
    return ChunkNegEdgeSubgraph(neg_g, num_chunks, chunk_size,
                                neg_sample_size, neg_head)

class EvalSampler(object):
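    """Samples evaluation edges and yields (pos_g, neg_g) pairs; when
    filter_false_neg is set, sampled negatives that are in fact true edges
    are marked with a negative 'bias' edge feature so ranking can exclude
    them."""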
    def __init__(self, g, edges, batch_size, neg_sample_size, neg_chunk_size, mode, num_workers,
                 filter_false_neg):
        EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
        self.sampler = EdgeSampler(g,
                                   batch_size=batch_size,
                                   seed_edges=edges,
                                   neg_sample_size=neg_sample_size,
                                   chunk_size=neg_chunk_size,
                                   negative_mode=mode,
                                   num_workers=num_workers,
                                   shuffle=False,
                                   exclude_positive=False,
                                   relations=g.edata['id'],
                                   return_false_neg=filter_false_neg)
        self.sampler_iter = iter(self.sampler)
        self.mode = mode
        self.neg_head = 'head' in mode
        self.g = g
        self.filter_false_neg = filter_false_neg
        self.neg_chunk_size = neg_chunk_size

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            pos_g, neg_g = next(self.sampler_iter)
            if self.filter_false_neg:
                neg_positive = neg_g.edata['false_neg']
            neg_g = create_neg_subgraph(pos_g, neg_g, self.neg_chunk_size, 'chunk' in self.mode,
                                        self.neg_head, self.g.number_of_nodes())
            if neg_g is not None:
                break

        pos_g.copy_from_parent()
        neg_g.copy_from_parent()
        if self.filter_false_neg:
            neg_g.edata['bias'] = F.astype(-neg_positive, F.float32)
        return pos_g, neg_g

    def reset(self):
        self.sampler_iter = iter(self.sampler)
        return self

class EvalDataset(object):
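    """Builds one graph over the train+valid+test triples and exposes the
    valid and test edge IDs for evaluation, optionally subsampled by
    args.eval_percent."""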
    def __init__(self, dataset, args):
        pickle_name = 'graph_all.pickle'
        if args.pickle_graph and os.path.exists(os.path.join(args.data_path, args.dataset, pickle_name)):
            with open(os.path.join(args.data_path, args.dataset, pickle_name), 'rb') as graph_file:
                g = pickle.load(graph_file)
                print('Load pickled graph.')
        else:
            src = np.concatenate((dataset.train[0], dataset.valid[0], dataset.test[0]))
            etype_id = np.concatenate((dataset.train[1], dataset.valid[1], dataset.test[1]))
            dst = np.concatenate((dataset.train[2], dataset.valid[2], dataset.test[2]))
            coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
                                       shape=[dataset.n_entities, dataset.n_entities])
            g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
            g.ndata['id'] = F.arange(0, g.number_of_nodes())
            g.edata['id'] = F.tensor(etype_id, F.int64)
            if args.pickle_graph:
                with open(os.path.join(args.data_path, args.dataset, pickle_name), 'wb') as graph_file:
                    pickle.dump(g, graph_file)
        self.g = g

        self.num_train = len(dataset.train[0])
        self.num_valid = len(dataset.valid[0])
        self.num_test = len(dataset.test[0])

        if args.eval_percent < 1:
            self.valid = np.random.randint(0, self.num_valid,
                    size=(int(self.num_valid * args.eval_percent),)) + self.num_train
        else:
            self.valid = np.arange(self.num_train, self.num_train + self.num_valid)
        print('|valid|:', len(self.valid))

        if args.eval_percent < 1:
            self.test = np.random.randint(0, self.num_test,
                    size=(int(self.num_test * args.eval_percent),))
            self.test += self.num_train + self.num_valid
        else:
            self.test = np.arange(self.num_train + self.num_valid, self.g.number_of_edges())
        print('|test|:', len(self.test))

        self.num_valid = len(self.valid)
        self.num_test = len(self.test)

    def get_edges(self, eval_type):
        if eval_type == 'valid':
            return self.valid
        elif eval_type == 'test':
            return self.test
        else:
            raise ValueError('invalid eval type: ' + eval_type)

    def check(self, eval_type):
        edges = self.get_edges(eval_type)
        subg = self.g.edge_subgraph(edges)
        if eval_type == 'valid':
            data = self.valid
        elif eval_type == 'test':
            data = self.test

        subg.copy_from_parent()
        src, dst, eid = subg.all_edges('all', order='eid')
        src_id = subg.ndata['id'][src]
        dst_id = subg.ndata['id'][dst]
        etype = subg.edata['id'][eid]

        orig_src = np.array([t[0] for t in data])
        orig_etype = np.array([t[1] for t in data])
        orig_dst = np.array([t[2] for t in data])
        np.testing.assert_equal(F.asnumpy(src_id), orig_src)
        np.testing.assert_equal(F.asnumpy(dst_id), orig_dst)
        np.testing.assert_equal(F.asnumpy(etype), orig_etype)

    def create_sampler(self, eval_type, batch_size, neg_sample_size, neg_chunk_size,
                       filter_false_neg, mode='head', num_workers=5, rank=0, ranks=1):
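        """Take this rank's slice of the evaluation edges and wrap it in an
        EvalSampler."""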
        edges = self.get_edges(eval_type)
        beg = edges.shape[0] * rank // ranks
        end = min(edges.shape[0] * (rank + 1) // ranks, edges.shape[0])
        edges = edges[beg: end]
        return EvalSampler(self.g, edges, batch_size, neg_sample_size, neg_chunk_size,
                           mode, num_workers, filter_false_neg)

class NewBidirectionalOneShotIterator:
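    """Alternates endlessly between a head-corrupting and a tail-corrupting
    sampler, yielding one (pos_g, neg_g) pair per step."""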
    def __init__(self, dataloader_head, dataloader_tail, neg_chunk_size, is_chunked, num_nodes):
        self.sampler_head = dataloader_head
        self.sampler_tail = dataloader_tail
        self.iterator_head = self.one_shot_iterator(dataloader_head, neg_chunk_size, is_chunked,
                                                    True, num_nodes)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail, neg_chunk_size, is_chunked,
                                                    False, num_nodes)
        self.step = 0

    def __next__(self):
        self.step += 1
        if self.step % 2 == 0:
            pos_g, neg_g = next(self.iterator_head)
        else:
            pos_g, neg_g = next(self.iterator_tail)
        return pos_g, neg_g

    @staticmethod
    def one_shot_iterator(dataloader, neg_chunk_size, is_chunked, neg_head, num_nodes):
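        """Loop over the dataloader forever, wrapping each negative graph
        with its chunk layout and skipping batches that cannot be chunked."""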
        while True:
            for pos_g, neg_g in dataloader:
                neg_g = create_neg_subgraph(pos_g, neg_g, neg_chunk_size, is_chunked,
                                            neg_head, num_nodes)
                if neg_g is None:
                    continue

                pos_g.copy_from_parent()
                neg_g.copy_from_parent()
                yield pos_g, neg_g
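
# A minimal usage sketch (hypothetical argument values; `args` and `dataset`
# come from the training script, not from this file):
#
#   train_data = TrainDataset(dataset, args, ranks=args.num_proc)
#   sampler_head = train_data.create_sampler(args.batch_size, args.neg_sample_size,
#                                            args.neg_sample_size, mode='chunk-head')
#   sampler_tail = train_data.create_sampler(args.batch_size, args.neg_sample_size,
#                                            args.neg_sample_size, mode='chunk-tail')
#   train_iter = NewBidirectionalOneShotIterator(sampler_head, sampler_tail,
#                                                args.neg_sample_size, True,
#                                                dataset.n_entities)
#   pos_g, neg_g = next(train_iter)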