"""Integration tests for GraphBolt data-loading pipelines (test_integration.py)."""

import dgl
import dgl.graphbolt as gb
import dgl.sparse as dglsp
import torch


def test_integration_link_prediction():
    """Compare every minibatch of a link-prediction pipeline to a golden string.

    A 6-node, 10-edge graph is built from CSC arrays and its edges are used
    as the item set.  The pipeline is: ``ItemSampler`` (batch size 4) ->
    uniform negative sampling (ratio 2) -> two-layer neighbor sampling with
    fanout 1 (with replacement) -> ``exclude_seed_edges`` -> feature
    fetching -> ``DataLoader``.  The torch RNG is seeded up front; the
    golden strings are expected to be reproducible under that seed.
    """
    torch.manual_seed(926)

    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])

    # Every edge of the graph becomes one item (a pair of node IDs).
    matrix_a = dglsp.from_csc(indptr, indices)
    node_pairs = torch.t(torch.stack(matrix_a.coo()))
    node_feature_data = torch.tensor(
        [
            [0.9634, 0.2294],
            [0.6172, 0.7865],
            [0.2109, 0.1089],
            [0.8672, 0.2276],
            [0.5503, 0.8223],
            [0.5160, 0.2486],
        ]
    )
    edge_feature_data = torch.tensor(
        [
            [0.5123, 0.1709, 0.6150],
            [0.1476, 0.1902, 0.1314],
            [0.2582, 0.5203, 0.6228],
            [0.3708, 0.7631, 0.2683],
            [0.2126, 0.7878, 0.7225],
            [0.7885, 0.3414, 0.5485],
            [0.4088, 0.8200, 0.1851],
            [0.0056, 0.9469, 0.4432],
            [0.8972, 0.7511, 0.3617],
            [0.5773, 0.2199, 0.3366],
        ]
    )

    # NOTE(review): the item set is named "node_pairs" while the golden
    # strings below show the data under `seeds` -- confirm this matches the
    # GraphBolt version under test.
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    node_feature = gb.TorchBasedFeature(node_feature_data)
    edge_feature = gb.TorchBasedFeature(edge_feature_data)
    features = {
        ("node", None, "feat"): node_feature,
        ("edge", None, "feat"): edge_feature,
    }
    feature_store = gb.BasicFeatureStore(features)
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    datapipe = datapipe.sample_uniform_negative(graph, 2)
    fanouts = torch.LongTensor([1])
    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(
        feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
    )
    dataloader = gb.DataLoader(datapipe)
    # Golden strings: whitespace inside the literals is significant because
    # they are compared verbatim with ``str(minibatch)``.
    expected = [
        """MiniBatch(seeds=tensor([[5, 1],
                        [3, 2],
                        [3, 2],
                        [3, 3],
                        [5, 0],
                        [5, 0],
                        [3, 3],
                        [3, 0],
                        [3, 5],
                        [3, 3],
                        [3, 3],
                        [3, 4]]),
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32),
                                                                         indices=tensor([0, 5, 4], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([5, 4], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
                            )],
          positive_node_pairs=None,
          node_pairs_with_labels=None,
          node_pairs=None,
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.6172, 0.7865],
                                [0.8672, 0.2276],
                                [0.2109, 0.1089],
                                [0.9634, 0.2294],
                                [0.5503, 0.8223]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
          input_nodes=tensor([5, 1, 3, 2, 0, 4]),
          indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
          edge_features=[{},
                        {}],
          compacted_seeds=tensor([[0, 1],
                                  [2, 3],
                                  [2, 3],
                                  [2, 2],
                                  [0, 4],
                                  [0, 4],
                                  [2, 2],
                                  [2, 4],
                                  [2, 0],
                                  [2, 2],
                                  [2, 2],
                                  [2, 5]]),
          compacted_node_pairs=None,
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
                 Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
       )""",
        """MiniBatch(seeds=tensor([[3, 3],
                        [4, 3],
                        [4, 4],
                        [0, 4],
                        [3, 1],
                        [3, 5],
                        [4, 2],
                        [4, 5],
                        [4, 4],
                        [4, 3],
                        [0, 1],
                        [0, 5]]),
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([4, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
                                                                         indices=tensor([4, 4, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                            )],
          positive_node_pairs=None,
          node_pairs_with_labels=None,
          node_pairs=None,
          node_features={'feat': tensor([[0.8672, 0.2276],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294],
                                [0.6172, 0.7865],
                                [0.5160, 0.2486],
                                [0.2109, 0.1089]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
          input_nodes=tensor([3, 4, 0, 1, 5, 2]),
          indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
          edge_features=[{},
                        {}],
          compacted_seeds=tensor([[0, 0],
                                  [1, 0],
                                  [1, 1],
                                  [2, 1],
                                  [0, 3],
                                  [0, 4],
                                  [1, 5],
                                  [1, 4],
                                  [1, 1],
                                  [1, 0],
                                  [2, 3],
                                  [2, 4]]),
          compacted_node_pairs=None,
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
                 Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
       )""",
        """MiniBatch(seeds=tensor([[5, 5],
                        [4, 5],
                        [5, 0],
                        [5, 4],
                        [4, 0],
                        [4, 1]]),
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([1, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0, 1]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4, 0, 1]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([1, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0, 1]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4, 0, 1]),
                            )],
          positive_node_pairs=None,
          node_pairs_with_labels=None,
          node_pairs=None,
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294],
                                [0.6172, 0.7865]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=tensor([1., 1., 0., 0., 0., 0.]),
          input_nodes=tensor([5, 4, 0, 1]),
          indexes=tensor([0, 1, 0, 0, 1, 1]),
          edge_features=[{},
                        {}],
          compacted_seeds=tensor([[0, 0],
                                  [1, 0],
                                  [0, 2],
                                  [0, 1],
                                  [1, 2],
                                  [1, 3]]),
          compacted_node_pairs=None,
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
                 Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
       )""",
    ]
    num_batches = 0
    for step, data in enumerate(dataloader):
        actual = str(data)
        # Put the actual repr in the failure message; `print(...)` as the
        # message would evaluate to None.
        assert expected[step] == actual, f"minibatch {step} mismatch:\n{actual}"
        num_batches = step + 1
    # Guard against a vacuous pass when the loader yields too few batches.
    assert num_batches == len(expected), (
        f"expected {len(expected)} minibatches, got {num_batches}"
    )


def test_integration_node_classification():
    """Compare every minibatch of a sampling pipeline to a golden string.

    A 6-node, 10-edge graph is built from CSC arrays and its edges (node
    pairs) are used as the item set.  The pipeline is: ``ItemSampler``
    (batch size 4) -> two-layer neighbor sampling with fanout 1 (with
    replacement) -> feature fetching -> ``DataLoader``.  Unlike the
    link-prediction test, no negative sampling or seed-edge exclusion is
    applied.  The torch RNG is seeded up front; the golden strings are
    expected to be reproducible under that seed.
    """
    torch.manual_seed(926)

    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])

    # Every edge of the graph becomes one item (a pair of node IDs).
    matrix_a = dglsp.from_csc(indptr, indices)
    node_pairs = torch.t(torch.stack(matrix_a.coo()))
    node_feature_data = torch.tensor(
        [
            [0.9634, 0.2294],
            [0.6172, 0.7865],
            [0.2109, 0.1089],
            [0.8672, 0.2276],
            [0.5503, 0.8223],
            [0.5160, 0.2486],
        ]
    )
    edge_feature_data = torch.tensor(
        [
            [0.5123, 0.1709, 0.6150],
            [0.1476, 0.1902, 0.1314],
            [0.2582, 0.5203, 0.6228],
            [0.3708, 0.7631, 0.2683],
            [0.2126, 0.7878, 0.7225],
            [0.7885, 0.3414, 0.5485],
            [0.4088, 0.8200, 0.1851],
            [0.0056, 0.9469, 0.4432],
            [0.8972, 0.7511, 0.3617],
            [0.5773, 0.2199, 0.3366],
        ]
    )

    # NOTE(review): the item set is named "node_pairs" while the golden
    # strings below show the data under `seeds` -- confirm this matches the
    # GraphBolt version under test.
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    node_feature = gb.TorchBasedFeature(node_feature_data)
    edge_feature = gb.TorchBasedFeature(edge_feature_data)
    features = {
        ("node", None, "feat"): node_feature,
        ("edge", None, "feat"): edge_feature,
    }
    feature_store = gb.BasicFeatureStore(features)
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    fanouts = torch.LongTensor([1])
    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
    datapipe = datapipe.fetch_feature(
        feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
    )
    dataloader = gb.DataLoader(datapipe)
    # Golden strings: whitespace inside the literals is significant because
    # they are compared verbatim with ``str(minibatch)``.
    expected = [
        """MiniBatch(seeds=tensor([[5, 1],
                        [3, 2],
                        [3, 2],
                        [3, 3]]),
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
                                                                         indices=tensor([4, 0, 2, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 1, 3, 2, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 1, 3, 2]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
                                                                         indices=tensor([0, 0, 2, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 1, 3, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 1, 3, 2]),
                            )],
          positive_node_pairs=None,
          node_pairs_with_labels=None,
          node_pairs=None,
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.6172, 0.7865],
                                [0.8672, 0.2276],
                                [0.2109, 0.1089],
                                [0.5503, 0.8223]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([5, 1, 3, 2, 4]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=tensor([[0, 1],
                                  [2, 3],
                                  [2, 3],
                                  [2, 2]]),
          compacted_node_pairs=None,
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
                 Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
       )""",
        """MiniBatch(seeds=tensor([[3, 3],
                        [4, 3],
                        [4, 4],
                        [0, 4]]),
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
                                                                         indices=tensor([0, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
                                                                         indices=tensor([0, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0]),
                            )],
          positive_node_pairs=None,
          node_pairs_with_labels=None,
          node_pairs=None,
          node_features={'feat': tensor([[0.8672, 0.2276],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([3, 4, 0]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=tensor([[0, 0],
                                  [1, 0],
                                  [1, 1],
                                  [2, 1]]),
          compacted_node_pairs=None,
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
                 Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
       )""",
        """MiniBatch(seeds=tensor([[5, 5],
                        [4, 5]]),
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([0, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([1, 1], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4]),
                            )],
          positive_node_pairs=None,
          node_pairs_with_labels=None,
          node_pairs=None,
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([5, 4, 0]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=tensor([[0, 0],
                                  [1, 0]]),
          compacted_node_pairs=None,
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
                 Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
       )""",
    ]
    num_batches = 0
    for step, data in enumerate(dataloader):
        actual = str(data)
        # Put the actual repr in the failure message; `print(...)` as the
        # message would evaluate to None.
        assert expected[step] == actual, f"minibatch {step} mismatch:\n{actual}"
        num_batches = step + 1
    # Guard against a vacuous pass when the loader yields too few batches.
    assert num_batches == len(expected), (
        f"expected {len(expected)} minibatches, got {num_batches}"
    )