sampler.py 13.7 KB
Newer Older
1
2
3
4
5
6
import math
import numpy as np
import scipy as sp
import dgl.backend as F
import dgl
import os
7
import sys
8
9
10
11
12
13
import pickle
import time

# This partitions a list of edges based on relations to make sure
# each partition has roughly the same number of edges and relations.
def RelationPartition(edges, n):
    """Split edges into ``n`` parts so that all edges of a relation share a part.

    Relations are visited from the most frequent to the least frequent and
    each one is assigned, as a whole, to the part that currently holds the
    fewest edges.

    Parameters
    ----------
    edges : tuple
        ``(heads, rels, tails)``. Only ``rels`` drives the partitioning;
        ``heads`` is used for the log message.
    n : int
        Number of parts, at least 1.

    Returns
    -------
    parts : list of numpy.ndarray
        ``parts[i]`` holds the int64 indices (ascending) of the edges in part ``i``.
    rel_parts : list of numpy.ndarray
        ``rel_parts[i]`` holds the relation ids assigned to part ``i``.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError('the number of parts must be positive, got {}'.format(n))

    heads, rels, tails = edges
    print('relation partition {} edges into {} parts'.format(len(heads), n))
    uniq, inverse, cnts = np.unique(rels, return_inverse=True, return_counts=True)
    inverse = inverse.reshape(-1)

    # Most frequent relations first. Equal counts (or a single relation) are
    # perfectly valid, so no strict ordering between first and last is required.
    order = np.flip(np.argsort(cnts))
    edge_cnts = np.zeros(shape=(n,), dtype=np.int64)
    rel_cnts = np.zeros(shape=(n,), dtype=np.int64)
    # rel_to_part[k] is the part that received relation uniq[k].
    rel_to_part = np.zeros(shape=(len(uniq),), dtype=np.int64)
    rel_parts = [[] for _ in range(n)]
    for k in order:
        part_idx = np.argmin(edge_cnts)
        rel_to_part[k] = part_idx
        rel_parts[part_idx].append(uniq[k])
        edge_cnts[part_idx] += cnts[k]
        rel_cnts[part_idx] += 1
    for i, edge_cnt in enumerate(edge_cnts):
        print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))

    rel_parts = [np.array(rel_part) for rel_part in rel_parts]

    # Map every edge to the part of its relation, then collect the edge
    # indices of each part (np.nonzero returns them in ascending order).
    edge_part = rel_to_part[inverse]
    parts = [np.nonzero(edge_part == i)[0].astype(np.int64) for i in range(n)]
    return parts, rel_parts
49
50

def RandomPartition(edges, n):
    """Shuffle the edge indices and cut them into ``n`` consecutive slices.

    Parameters
    ----------
    edges : tuple
        ``(heads, rels, tails)``; only the length of ``heads`` is used.
    n : int
        Number of parts.

    Returns
    -------
    list of numpy.ndarray
        ``n`` slices of a random permutation of ``range(len(heads))``. Every
        slice has ``ceil(len(heads) / n)`` entries except the trailing ones,
        which may be shorter or empty.
    """
    heads, _, _ = edges
    print('random partition {} edges into {} parts'.format(len(heads), n))
    shuffled = np.random.permutation(len(heads))
    total = len(shuffled)
    chunk = int(math.ceil(total / n))

    parts = []
    for part_id in range(n):
        lo = chunk * part_id
        hi = min(lo + chunk, total)
        piece = shuffled[lo:hi]
        parts.append(piece)
        print('part {} has {} edges'.format(part_id, len(piece)))
    return parts

63
64
def ConstructGraph(edges, n_entities, args):
    """Build the training graph, or load it from a pickle cache.

    Parameters
    ----------
    edges : tuple
        ``(src, etype_id, dst)`` describing the edges.
    n_entities : int
        Number of nodes; the adjacency matrix is ``n_entities x n_entities``.
    args : object
        Reads ``args.pickle_graph`` and, when it is truthy, ``args.data_path``
        and ``args.dataset`` to locate ``graph_train.pickle``.

    Returns
    -------
    The graph, either unpickled from the cache file or freshly built with the
    relation ids stored in ``edata['tid']``.
    """
    # The cache location is only resolved when caching is requested.
    cache_path = None
    if args.pickle_graph:
        cache_path = os.path.join(args.data_path, args.dataset, 'graph_train.pickle')

    if cache_path is not None and os.path.exists(cache_path):
        # NOTE: unpickling executes arbitrary code; the cache file must be trusted.
        with open(cache_path, 'rb') as graph_file:
            g = pickle.load(graph_file)
            print('Load pickled graph.')
        return g

    src, etype_id, dst = edges
    adjacency = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
                                     shape=[n_entities, n_entities])
    g = dgl.DGLGraph(adjacency, readonly=True, sort_csr=True)
    g.edata['tid'] = F.tensor(etype_id, F.int64)

    if args.pickle_graph:
        cache_path = os.path.join(args.data_path, args.dataset, 'graph_train.pickle')
        with open(cache_path, 'wb') as graph_file:
            pickle.dump(g, graph_file)
    return g

class TrainDataset(object):
    """Training graph plus the per-rank partition of its edges.

    Parameters
    ----------
    dataset : object
        Provides ``train`` as ``(heads, rels, tails)`` and ``n_entities``.
    args : object
        Forwarded to ``ConstructGraph``; ``args.rel_part`` selects
        relation-based partitioning when ``ranks > 1``.
    weighting : bool
        When true, a per-edge subsampling weight is stored in
        ``self.g.edata['weight']``.
    ranks : int
        Number of edge partitions to create.
    """

    def __init__(self, dataset, args, weighting=False, ranks=64):
        triples = dataset.train
        self.g = ConstructGraph(triples, dataset.n_entities, args)
        num_train = len(triples[0])
        print('|Train|:', num_train)
        if ranks > 1 and args.rel_part:
            self.edge_parts, self.rel_parts = RelationPartition(triples, ranks)
        elif ranks > 1:
            self.edge_parts = RandomPartition(triples, ranks)
        else:
            self.edge_parts = [np.arange(num_train)]
        if weighting:
            # TODO: weight to be added
            # NOTE(review): assumes edge ids of self.g follow the order of
            # `triples` -- confirm against how ConstructGraph orders edges.
            weight = self._subsampling_weight(triples)
            self.g.edata['weight'] = F.zerocopy_from_numpy(weight)

    def _subsampling_weight(self, triples):
        """Return ``sqrt(1 / (count[(h, r)] + count[(t, -r - 1)]))`` per triple.

        ``triples`` is the ``(heads, rels, tails)`` column tuple; the result is
        a float64 array with one entry per triple, in input order.
        """
        # Plain Python ints keep the dictionary keys independent of the array
        # dtype (e.g. negating an unsigned relation id would wrap around).
        heads, rels, tails = (np.asarray(column).tolist() for column in triples)
        rows = list(zip(heads, rels, tails))
        count = self.count_freq(rows)
        denom = np.array([count[(h, r)] + count[(t, -r - 1)] for h, r, t in rows],
                         dtype=np.float64)
        return np.sqrt(1 / denom)

    def count_freq(self, triples, start=4):
        """Count ``(head, rel)`` and ``(tail, -rel - 1)`` occurrences.

        Parameters
        ----------
        triples : iterable
            Iterable of ``(head, rel, tail)`` rows.
        start : int
            Value a key receives on its first occurrence; later occurrences
            add one.

        Returns
        -------
        dict
            Maps ``(head, rel)`` and ``(tail, -rel - 1)`` to their counts.
        """
        count = {}
        for head, rel, tail in triples:
            if (head, rel) not in count:
                count[(head, rel)] = start
            else:
                count[(head, rel)] += 1

            if (tail, -rel - 1) not in count:
                count[(tail, -rel - 1)] = start
            else:
                count[(tail, -rel - 1)] += 1
        return count

    def create_sampler(self, batch_size, neg_sample_size=2, neg_chunk_size=None, mode='head', num_workers=5,
                       shuffle=True, exclude_positive=False, rank=0):
        """Create an ``EdgeSampler`` over the edges of partition ``rank``.

        The seed edges are ``self.edge_parts[rank]``; the remaining arguments
        are forwarded to ``dgl.contrib.sampling.EdgeSampler`` unchanged
        (``neg_chunk_size`` as ``chunk_size``, ``mode`` as ``negative_mode``).
        """
        EdgeSampler = dgl.contrib.sampling.EdgeSampler
        return EdgeSampler(self.g,
                           seed_edges=F.tensor(self.edge_parts[rank]),
                           batch_size=batch_size,
                           neg_sample_size=neg_sample_size,
                           chunk_size=neg_chunk_size,
                           negative_mode=mode,
                           num_workers=num_workers,
                           shuffle=shuffle,
                           exclude_positive=exclude_positive,
                           return_false_neg=False)

128

129
class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
    """Negative-edge subgraph annotated with its chunking layout.

    Wraps ``subg`` (re-using its ``_parent`` and ``sgi``) and records the
    number of chunks, the chunk size, the negative sample size and whether
    the head side was corrupted.
    """

    def __init__(self, subg, num_chunks, chunk_size,
                 neg_sample_size, neg_head):
        super().__init__(subg._parent, subg.sgi)
        # Keep the wrapped subgraph around: head/tail node ids are read from it.
        self.subg = subg
        self.num_chunks = num_chunks
        self.chunk_size = chunk_size
        self.neg_sample_size = neg_sample_size
        self.neg_head = neg_head

    @property
    def head_nid(self):
        """``head_nid`` of the wrapped subgraph."""
        return self.subg.head_nid

    @property
    def tail_nid(self):
        """``tail_nid`` of the wrapped subgraph."""
        return self.subg.tail_nid


148
149
150
151
152
# KG models need to know the number of chunks, the chunk size and the negative sample size
# of a negative subgraph to perform the computation more efficiently.
# This function tries to infer all of this information from the negative subgraph
# and creates a wrapper object that contains all of it.
def create_neg_subgraph(pos_g, neg_g, chunk_size, is_chunked, neg_head, num_nodes):
    """Wrap ``neg_g`` together with its inferred chunk layout.

    Parameters
    ----------
    pos_g, neg_g
        Positive and negative subgraphs; the number of negative edges must be
        a multiple of the number of positive edges.
    chunk_size : int
        Requested chunk size (only used when ``is_chunked`` is true).
    is_chunked : bool
        Whether negatives were sampled per chunk.
    neg_head : bool
        Whether the head side (rather than the tail side) was corrupted.
    num_nodes : int
        Node count compared against the number of corrupted-side node ids.

    Returns
    -------
    ChunkNegEdgeSubgraph or None
        ``None`` when the batch is chunked and the number of positive edges is
        at least ``chunk_size`` but not a multiple of it.
    """
    n_pos = pos_g.number_of_edges()
    n_neg = neg_g.number_of_edges()
    assert n_neg % n_pos == 0
    neg_sample_size = int(n_neg / n_pos)

    corrupted_nid = neg_g.head_nid if neg_head else neg_g.tail_nid
    if len(corrupted_nid) == num_nodes:
        # We use all nodes to create negative edges. Regardless of the sampling
        # algorithm, we can always view the subgraph with one chunk.
        num_chunks, chunk_size = 1, n_pos
    elif not is_chunked:
        num_chunks, chunk_size = n_pos, 1
    else:
        if n_pos < chunk_size:
            num_chunks, chunk_size = 1, n_pos
        elif n_pos % chunk_size > 0:
            # This is probably the last batch. Let's ignore it.
            return None
        else:
            num_chunks = int(n_pos / chunk_size)
        assert num_chunks * chunk_size == n_pos
        assert num_chunks * neg_sample_size * chunk_size == n_neg

    return ChunkNegEdgeSubgraph(neg_g, num_chunks, chunk_size,
                                neg_sample_size, neg_head)
177
178

class EvalSampler(object):
    """Iterator over ``(pos_g, neg_g)`` evaluation batches.

    Wraps a ``dgl.contrib.sampling.EdgeSampler`` created without shuffling and
    without excluding positives, and post-processes every batch with
    ``create_neg_subgraph``.
    """

    def __init__(self, g, edges, batch_size, neg_sample_size, neg_chunk_size, mode, num_workers,
                 filter_false_neg):
        self.g = g
        self.mode = mode
        self.neg_head = 'head' in mode
        self.neg_chunk_size = neg_chunk_size
        self.filter_false_neg = filter_false_neg

        sampler_options = dict(batch_size=batch_size,
                               seed_edges=edges,
                               neg_sample_size=neg_sample_size,
                               chunk_size=neg_chunk_size,
                               negative_mode=mode,
                               num_workers=num_workers,
                               shuffle=False,
                               exclude_positive=False,
                               relations=g.edata['tid'],
                               return_false_neg=filter_false_neg)
        self.sampler = dgl.contrib.sampling.EdgeSampler(g, **sampler_options)
        self.sampler_iter = iter(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next usable batch, skipping batches that cannot be wrapped.

        ``StopIteration`` from the underlying sampler propagates to the caller.
        """
        wrapped = None
        while wrapped is None:
            pos_g, raw_neg_g = next(self.sampler_iter)
            if self.filter_false_neg:
                # Read the flags before the negative graph is wrapped.
                false_neg = raw_neg_g.edata['false_neg']
            wrapped = create_neg_subgraph(pos_g, raw_neg_g, self.neg_chunk_size,
                                          'chunk' in self.mode,
                                          self.neg_head, self.g.number_of_nodes())
        neg_g = wrapped

        pos_g.ndata['id'] = pos_g.parent_nid
        neg_g.ndata['id'] = neg_g.parent_nid
        pos_g.edata['id'] = pos_g._parent.edata['tid'][pos_g.parent_eid]
        if self.filter_false_neg:
            neg_g.edata['bias'] = F.astype(-false_neg, F.float32)
        return pos_g, neg_g

    def reset(self):
        """Restart iteration over the underlying sampler and return ``self``."""
        self.sampler_iter = iter(self.sampler)
        return self

class EvalDataset(object):
    """Graph over train + valid + test edges and the edge ids to evaluate.

    Edges are stored in the order train, valid, test, so validation edge ids
    start at ``num_train`` and test edge ids at ``num_train + num_valid``.
    When ``args.eval_percent < 1`` the evaluated ids are drawn with
    ``np.random.randint`` (i.e. with replacement).
    """

    def __init__(self, dataset, args):
        # The cache location is only resolved when caching is requested.
        cache_path = None
        if args.pickle_graph:
            cache_path = os.path.join(args.data_path, args.dataset, 'graph_all.pickle')

        if cache_path is not None and os.path.exists(cache_path):
            # NOTE: unpickling executes arbitrary code; the cache file must be trusted.
            with open(cache_path, 'rb') as graph_file:
                g = pickle.load(graph_file)
                print('Load pickled graph.')
        else:
            splits = (dataset.train, dataset.valid, dataset.test)
            src = np.concatenate([split[0] for split in splits])
            etype_id = np.concatenate([split[1] for split in splits])
            dst = np.concatenate([split[2] for split in splits])
            adjacency = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
                                             shape=[dataset.n_entities, dataset.n_entities])
            g = dgl.DGLGraph(adjacency, readonly=True, sort_csr=True)
            g.edata['tid'] = F.tensor(etype_id, F.int64)
            if args.pickle_graph:
                cache_path = os.path.join(args.data_path, args.dataset, 'graph_all.pickle')
                with open(cache_path, 'wb') as graph_file:
                    pickle.dump(g, graph_file)
        self.g = g

        self.num_train = len(dataset.train[0])
        self.num_valid = len(dataset.valid[0])
        self.num_test = len(dataset.test[0])
        valid_offset = self.num_train
        test_offset = self.num_train + self.num_valid

        if args.eval_percent < 1:
            n_sampled = int(self.num_valid * args.eval_percent)
            self.valid = np.random.randint(0, self.num_valid, size=(n_sampled,)) + valid_offset
        else:
            self.valid = np.arange(valid_offset, valid_offset + self.num_valid)
        print('|valid|:', len(self.valid))

        if args.eval_percent < 1:
            n_sampled = int(self.num_test * args.eval_percent)
            self.test = np.random.randint(0, self.num_test, size=(n_sampled,)) + test_offset
        else:
            self.test = np.arange(test_offset, self.g.number_of_edges())
        print('|test|:', len(self.test))

        # From here on the counts describe the edges that will be evaluated.
        self.num_valid = len(self.valid)
        self.num_test = len(self.test)

    def get_edges(self, eval_type):
        """Return the edge ids for ``'valid'`` or ``'test'``; raise otherwise."""
        if eval_type == 'valid':
            return self.valid
        if eval_type == 'test':
            return self.test
        raise Exception('get invalid type: ' + eval_type)

    def check(self, eval_type):
        """Compare the edge subgraph of ``eval_type`` with the stored edge data.

        NOTE(review): ``self.valid`` / ``self.test`` hold edge ids here, while
        this method indexes each entry like a ``(head, rel, tail)`` triple and
        reads ``ndata['id']`` / ``edata['id']`` that this class never sets --
        this looks stale; confirm before relying on it.
        """
        edges = self.get_edges(eval_type)
        subg = self.g.edge_subgraph(edges)
        # get_edges already rejected anything other than 'valid' / 'test'.
        data = self.valid if eval_type == 'valid' else self.test

        subg.copy_from_parent()
        src, dst, eid = subg.all_edges('all', order='eid')
        src_id = subg.ndata['id'][src]
        dst_id = subg.ndata['id'][dst]
        etype = subg.edata['id'][eid]

        orig_src, orig_etype, orig_dst = (np.array([t[col] for t in data])
                                          for col in range(3))
        np.testing.assert_equal(F.asnumpy(src_id), orig_src)
        np.testing.assert_equal(F.asnumpy(dst_id), orig_dst)
        np.testing.assert_equal(F.asnumpy(etype), orig_etype)

    def create_sampler(self, eval_type, batch_size, neg_sample_size, neg_chunk_size,
                       filter_false_neg, mode='head', num_workers=5, rank=0, ranks=1):
        """Create an ``EvalSampler`` over the slice of edges owned by ``rank``.

        The edges of ``eval_type`` are split into ``ranks`` contiguous slices
        and slice number ``rank`` is used as seed edges.
        """
        edges = self.get_edges(eval_type)
        total = edges.shape[0]
        beg = total * rank // ranks
        end = min(total * (rank + 1) // ranks, total)
        return EvalSampler(self.g, edges[beg:end], batch_size, neg_sample_size, neg_chunk_size,
                           mode, num_workers, filter_false_neg)
302
303

class NewBidirectionalOneShotIterator:
    """Endless iterator alternating between a tail and a head dataloader.

    Each call to ``next`` yields a ``(pos_g, neg_g)`` pair: odd steps read
    from the tail dataloader, even steps from the head dataloader. Both
    dataloaders are restarted whenever they are exhausted.
    """

    def __init__(self, dataloader_head, dataloader_tail, neg_chunk_size, is_chunked, num_nodes):
        self.sampler_head = dataloader_head
        self.sampler_tail = dataloader_tail
        self.iterator_head = self.one_shot_iterator(dataloader_head, neg_chunk_size, is_chunked,
                                                    True, num_nodes)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail, neg_chunk_size, is_chunked,
                                                    False, num_nodes)
        self.step = 0

    def __iter__(self):
        return self

    def __next__(self):
        self.step += 1
        if self.step % 2 == 0:
            pos_g, neg_g = next(self.iterator_head)
        else:
            pos_g, neg_g = next(self.iterator_tail)
        return pos_g, neg_g

    @staticmethod
    def one_shot_iterator(dataloader, neg_chunk_size, is_chunked, neg_head, num_nodes):
        """Yield wrapped ``(pos_g, neg_g)`` batches from ``dataloader`` forever.

        Batches for which ``create_neg_subgraph`` returns ``None`` are skipped.

        Raises
        ------
        RuntimeError
            If a complete pass over ``dataloader`` yields no usable batch;
            without this check the generator would spin forever.
        """
        while True:
            produced = False
            for pos_g, neg_g in dataloader:
                neg_g = create_neg_subgraph(pos_g, neg_g, neg_chunk_size, is_chunked,
                                            neg_head, num_nodes)
                if neg_g is None:
                    continue

                pos_g.ndata['id'] = pos_g.parent_nid
                neg_g.ndata['id'] = neg_g.parent_nid
                pos_g.edata['id'] = pos_g._parent.edata['tid'][pos_g.parent_eid]
                produced = True
                yield pos_g, neg_g
            if not produced:
                raise RuntimeError('dataloader produced no usable batch in a full pass')