"""Integration tests that run small GraphBolt data-loading pipelines end to end."""

import dgl
import dgl.graphbolt as gb
import dgl.sparse as dglsp
import torch


def test_integration_link_prediction():
    """Check a link-prediction style GraphBolt pipeline against recorded output.

    The pipeline built here is: ``ItemSampler`` -> ``sample_uniform_negative``
    -> ``sample_neighbor`` (two layers) -> ``exclude_seed_edges`` ->
    ``fetch_feature`` -> ``DataLoader``. Every minibatch produced by the
    dataloader is converted with ``str()`` and compared to a recorded string.
    """
    # Sampling is random; the expected strings below were presumably recorded
    # under this seed, so the order of RNG-consuming calls must not change.
    torch.manual_seed(926)

    # CSC description of a graph with 6 nodes (len(indptr) - 1) and 10 edges.
    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])

    # Stack the COO form of the same matrix and transpose it so that each row
    # is one node pair; these pairs are the items to iterate over.
    matrix_a = dglsp.from_csc(indptr, indices)
    node_pairs = torch.t(torch.stack(matrix_a.coo()))
    node_feature_data = torch.tensor(
        [
            [0.9634, 0.2294],
            [0.6172, 0.7865],
            [0.2109, 0.1089],
            [0.8672, 0.2276],
            [0.5503, 0.8223],
            [0.5160, 0.2486],
        ]
    )
    edge_feature_data = torch.tensor(
        [
            [0.5123, 0.1709, 0.6150],
            [0.1476, 0.1902, 0.1314],
            [0.2582, 0.5203, 0.6228],
            [0.3708, 0.7631, 0.2683],
            [0.2126, 0.7878, 0.7225],
            [0.7885, 0.3414, 0.5485],
            [0.4088, 0.8200, 0.1851],
            [0.0056, 0.9469, 0.4432],
            [0.8972, 0.7511, 0.3617],
            [0.5773, 0.2199, 0.3366],
        ]
    )

    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    node_feature = gb.TorchBasedFeature(node_feature_data)
    edge_feature = gb.TorchBasedFeature(edge_feature_data)
    features = {
        ("node", None, "feat"): node_feature,
        ("edge", None, "feat"): edge_feature,
    }
    feature_store = gb.BasicFeatureStore(features)
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    datapipe = datapipe.sample_uniform_negative(graph, 2)
    fanouts = torch.LongTensor([1])
    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(
        feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
    )
    dataloader = gb.DataLoader(
        datapipe,
    )
    # One recorded string per minibatch. The whitespace inside these literals
    # is significant because they are compared verbatim with str(minibatch).
    expected = [
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([0, 4], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([5, 4], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
                            )],
          positive_node_pairs=(tensor([0, 1, 1, 1]),
                              tensor([2, 3, 3, 1])),
          node_pairs_with_labels=((tensor([0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]), tensor([2, 3, 3, 1, 4, 4, 1, 4, 0, 1, 1, 5])),
                                 tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])),
          node_pairs=(tensor([5, 3, 3, 3]),
                     tensor([1, 2, 2, 3])),
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.8672, 0.2276],
                                [0.6172, 0.7865],
                                [0.2109, 0.1089],
                                [0.9634, 0.2294],
                                [0.5503, 0.8223]])},
          negative_srcs=None,
          negative_node_pairs=(tensor([[0, 0],
                                      [1, 1],
                                      [1, 1],
                                      [1, 1]]),
                              tensor([[4, 4],
                                      [1, 4],
                                      [0, 1],
                                      [1, 5]])),
          negative_dsts=tensor([[0, 0],
                                [3, 0],
                                [5, 3],
                                [3, 4]]),
          labels=None,
          input_nodes=tensor([5, 3, 1, 2, 0, 4]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=None,
          compacted_node_pairs=(tensor([0, 1, 1, 1]),
                               tensor([2, 3, 3, 1])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=tensor([[4, 4],
                                          [1, 4],
                                          [0, 1],
                                          [1, 5]]),
          blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
                 Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
       )""",
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
                                                                         indices=tensor([4, 1, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
                                                                         indices=tensor([4, 4, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                            )],
          positive_node_pairs=(tensor([0, 1, 1, 2]),
                              tensor([0, 0, 1, 1])),
          node_pairs_with_labels=((tensor([0, 1, 1, 2, 0, 0, 1, 1, 1, 1, 2, 2]), tensor([0, 0, 1, 1, 3, 4, 5, 4, 1, 0, 3, 4])),
                                 tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])),
          node_pairs=(tensor([3, 4, 4, 0]),
                     tensor([3, 3, 4, 4])),
          node_features={'feat': tensor([[0.8672, 0.2276],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294],
                                [0.6172, 0.7865],
                                [0.5160, 0.2486],
                                [0.2109, 0.1089]])},
          negative_srcs=None,
          negative_node_pairs=(tensor([[0, 0],
                                      [1, 1],
                                      [1, 1],
                                      [2, 2]]),
                              tensor([[3, 4],
                                      [5, 4],
                                      [1, 0],
                                      [3, 4]])),
          negative_dsts=tensor([[1, 5],
                                [2, 5],
                                [4, 3],
                                [1, 5]]),
          labels=None,
          input_nodes=tensor([3, 4, 0, 1, 5, 2]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=None,
          compacted_node_pairs=(tensor([0, 1, 1, 2]),
                               tensor([0, 0, 1, 1])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=tensor([[3, 4],
                                          [5, 4],
                                          [1, 0],
                                          [3, 4]]),
          blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
                 Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
       )""",
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([1, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0, 1]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4, 0, 1]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([1, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0, 1]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4, 0, 1]),
                            )],
          positive_node_pairs=(tensor([0, 1]),
                              tensor([0, 0])),
          node_pairs_with_labels=((tensor([0, 1, 0, 0, 1, 1]), tensor([0, 0, 2, 1, 2, 3])),
                                 tensor([1., 1., 0., 0., 0., 0.])),
          node_pairs=(tensor([5, 4]),
                     tensor([5, 5])),
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294],
                                [0.6172, 0.7865]])},
          negative_srcs=None,
          negative_node_pairs=(tensor([[0, 0],
                                      [1, 1]]),
                              tensor([[2, 1],
                                      [2, 3]])),
          negative_dsts=tensor([[0, 4],
                                [0, 1]]),
          labels=None,
          input_nodes=tensor([5, 4, 0, 1]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=None,
          compacted_node_pairs=(tensor([0, 1]),
                               tensor([0, 0])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=tensor([[2, 1],
                                          [2, 3]]),
          blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
                 Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
       )""",
    ]
    actual = [str(minibatch) for minibatch in dataloader]
    # Guard against a vacuous pass (too few batches) and against an opaque
    # IndexError (too many batches).
    assert len(actual) == len(expected), (
        f"expected {len(expected)} minibatches, got {len(actual)}"
    )
    for step, (want, got) in enumerate(zip(expected, actual)):
        assert want == got, f"minibatch {step} does not match:\n{got}"


def test_integration_node_classification():
    """Check a neighbor-sampling GraphBolt pipeline against recorded output.

    The pipeline built here is: ``ItemSampler`` -> ``sample_neighbor`` (two
    layers) -> ``fetch_feature`` -> ``DataLoader``; no negative sampling or
    seed-edge exclusion is applied. Every minibatch produced by the dataloader
    is converted with ``str()`` and compared to a recorded string.
    """
    # Sampling is random; the expected strings below were presumably recorded
    # under this seed, so the order of RNG-consuming calls must not change.
    torch.manual_seed(926)

    # CSC description of a graph with 6 nodes (len(indptr) - 1) and 10 edges.
    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])

    # Stack the COO form of the same matrix and transpose it so that each row
    # is one node pair; these pairs are the items to iterate over.
    matrix_a = dglsp.from_csc(indptr, indices)
    node_pairs = torch.t(torch.stack(matrix_a.coo()))
    node_feature_data = torch.tensor(
        [
            [0.9634, 0.2294],
            [0.6172, 0.7865],
            [0.2109, 0.1089],
            [0.8672, 0.2276],
            [0.5503, 0.8223],
            [0.5160, 0.2486],
        ]
    )
    edge_feature_data = torch.tensor(
        [
            [0.5123, 0.1709, 0.6150],
            [0.1476, 0.1902, 0.1314],
            [0.2582, 0.5203, 0.6228],
            [0.3708, 0.7631, 0.2683],
            [0.2126, 0.7878, 0.7225],
            [0.7885, 0.3414, 0.5485],
            [0.4088, 0.8200, 0.1851],
            [0.0056, 0.9469, 0.4432],
            [0.8972, 0.7511, 0.3617],
            [0.5773, 0.2199, 0.3366],
        ]
    )

    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    node_feature = gb.TorchBasedFeature(node_feature_data)
    edge_feature = gb.TorchBasedFeature(edge_feature_data)
    features = {
        ("node", None, "feat"): node_feature,
        ("edge", None, "feat"): edge_feature,
    }
    feature_store = gb.BasicFeatureStore(features)
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    fanouts = torch.LongTensor([1])
    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
    datapipe = datapipe.fetch_feature(
        feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
    )
    dataloader = gb.DataLoader(
        datapipe,
    )
    # One recorded string per minibatch. The whitespace inside these literals
    # is significant because they are compared verbatim with str(minibatch).
    expected = [
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
                                                                         indices=tensor([4, 1, 0, 1], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 3, 1, 2, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 3, 1, 2]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
                                                                         indices=tensor([0, 1, 0, 1], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 3, 1, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 3, 1, 2]),
                            )],
          positive_node_pairs=(tensor([0, 1, 1, 1]),
                              tensor([2, 3, 3, 1])),
          node_pairs_with_labels=None,
          node_pairs=(tensor([5, 3, 3, 3]),
                     tensor([1, 2, 2, 3])),
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.8672, 0.2276],
                                [0.6172, 0.7865],
                                [0.2109, 0.1089],
                                [0.5503, 0.8223]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([5, 3, 1, 2, 4]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=None,
          compacted_node_pairs=(tensor([0, 1, 1, 1]),
                               tensor([2, 3, 3, 1])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
                 Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
       )""",
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
                                                                         indices=tensor([0, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
                                                                         indices=tensor([0, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0]),
                            )],
          positive_node_pairs=(tensor([0, 1, 1, 2]),
                              tensor([0, 0, 1, 1])),
          node_pairs_with_labels=None,
          node_pairs=(tensor([3, 4, 4, 0]),
                     tensor([3, 3, 4, 4])),
          node_features={'feat': tensor([[0.8672, 0.2276],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([3, 4, 0]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=None,
          compacted_node_pairs=(tensor([0, 1, 1, 2]),
                               tensor([0, 0, 1, 1])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
                 Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
       )""",
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([0, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([1, 1], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([5, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4]),
                            )],
          positive_node_pairs=(tensor([0, 1]),
                              tensor([0, 0])),
          node_pairs_with_labels=None,
          node_pairs=(tensor([5, 4]),
                     tensor([5, 5])),
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([5, 4, 0]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_seeds=None,
          compacted_node_pairs=(tensor([0, 1]),
                               tensor([0, 0])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
                 Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
       )""",
    ]
    actual = [str(minibatch) for minibatch in dataloader]
    # Guard against a vacuous pass (too few batches) and against an opaque
    # IndexError (too many batches).
    assert len(actual) == len(expected), (
        f"expected {len(expected)} minibatches, got {len(actual)}"
    )
    for step, (want, got) in enumerate(zip(expected, actual)):
        assert want == got, f"minibatch {step} does not match:\n{got}"