# test_integration.py
import dgl
import dgl.graphbolt as gb
import dgl.sparse as dglsp
import torch


def test_integration_link_prediction():
    """Check the printed minibatches of a link-prediction GraphBolt pipeline.

    Builds a 6-node, 10-edge graph from a fixed CSC layout and runs the
    pipeline ``ItemSampler -> sample_uniform_negative -> sample_neighbor ->
    exclude_seed_edges -> fetch_feature -> DataLoader``. The string form of
    every produced minibatch is then compared against a recorded snapshot.

    The RNG is seeded first, so the expected strings are only valid for this
    exact seed and this exact order of pipeline stages.
    """
    torch.manual_seed(926)

    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])

    # Every edge of the graph is used as a seed node pair.
    matrix_a = dglsp.from_csc(indptr, indices)
    node_pairs = torch.t(torch.stack(matrix_a.coo()))
    node_feature_data = torch.tensor(
        [
            [0.9634, 0.2294],
            [0.6172, 0.7865],
            [0.2109, 0.1089],
            [0.8672, 0.2276],
            [0.5503, 0.8223],
            [0.5160, 0.2486],
        ]
    )
    edge_feature_data = torch.tensor(
        [
            [0.5123, 0.1709, 0.6150],
            [0.1476, 0.1902, 0.1314],
            [0.2582, 0.5203, 0.6228],
            [0.3708, 0.7631, 0.2683],
            [0.2126, 0.7878, 0.7225],
            [0.7885, 0.3414, 0.5485],
            [0.4088, 0.8200, 0.1851],
            [0.0056, 0.9469, 0.4432],
            [0.8972, 0.7511, 0.3617],
            [0.5773, 0.2199, 0.3366],
        ]
    )

    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    node_feature = gb.TorchBasedFeature(node_feature_data)
    edge_feature = gb.TorchBasedFeature(edge_feature_data)
    features = {
        ("node", None, "feat"): node_feature,
        ("edge", None, "feat"): edge_feature,
    }
    feature_store = gb.BasicFeatureStore(features)
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    datapipe = datapipe.sample_uniform_negative(graph, 2)
    fanouts = torch.LongTensor([1])
    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(
        feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
    )
    dataloader = gb.DataLoader(
        datapipe,
    )
    # Snapshot of str(minibatch) for each batch; whitespace is significant.
    expected = [
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
                                                                         indices=tensor([0, 4]),
                                                           ),
                                               original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
                                                                         indices=tensor([5, 4]),
                                                           ),
                                               original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
                            )],
          positive_node_pairs=(tensor([0, 1, 1, 1]),
                              tensor([2, 3, 3, 1])),
          node_pairs_with_labels=((tensor([0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]), tensor([2, 3, 3, 1, 4, 4, 1, 4, 0, 1, 1, 5])),
                                 tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])),
          node_pairs=(tensor([5, 3, 3, 3]),
                     tensor([1, 2, 2, 3])),
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.8672, 0.2276],
                                [0.6172, 0.7865],
                                [0.2109, 0.1089],
                                [0.9634, 0.2294],
                                [0.5503, 0.8223]])},
          negative_srcs=None,
          negative_node_pairs=(tensor([[0, 0],
                                      [1, 1],
                                      [1, 1],
                                      [1, 1]]),
                              tensor([[4, 4],
                                      [1, 4],
                                      [0, 1],
                                      [1, 5]])),
          negative_dsts=tensor([[0, 0],
                                [3, 0],
                                [5, 3],
                                [3, 4]]),
          labels=None,
          input_nodes=tensor([5, 3, 1, 2, 0, 4]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_node_pairs=(tensor([0, 1, 1, 1]),
                               tensor([2, 3, 3, 1])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=tensor([[4, 4],
                                          [1, 4],
                                          [0, 1],
                                          [1, 5]]),
          blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
                 Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
       )""",
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]),
                                                                         indices=tensor([4, 1, 0]),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]),
                                                                         indices=tensor([4, 4, 0]),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
                            )],
          positive_node_pairs=(tensor([0, 1, 1, 2]),
                              tensor([0, 0, 1, 1])),
          node_pairs_with_labels=((tensor([0, 1, 1, 2, 0, 0, 1, 1, 1, 1, 2, 2]), tensor([0, 0, 1, 1, 3, 4, 5, 4, 1, 0, 3, 4])),
                                 tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])),
          node_pairs=(tensor([3, 4, 4, 0]),
                     tensor([3, 3, 4, 4])),
          node_features={'feat': tensor([[0.8672, 0.2276],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294],
                                [0.6172, 0.7865],
                                [0.5160, 0.2486],
                                [0.2109, 0.1089]])},
          negative_srcs=None,
          negative_node_pairs=(tensor([[0, 0],
                                      [1, 1],
                                      [1, 1],
                                      [2, 2]]),
                              tensor([[3, 4],
                                      [5, 4],
                                      [1, 0],
                                      [3, 4]])),
          negative_dsts=tensor([[1, 5],
                                [2, 5],
                                [4, 3],
                                [1, 5]]),
          labels=None,
          input_nodes=tensor([3, 4, 0, 1, 5, 2]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_node_pairs=(tensor([0, 1, 1, 2]),
                               tensor([0, 0, 1, 1])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=tensor([[3, 4],
                                          [5, 4],
                                          [1, 0],
                                          [3, 4]]),
          blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
                 Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
       )""",
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]),
                                                                         indices=tensor([1, 0]),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0, 1]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4, 0, 1]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]),
                                                                         indices=tensor([1, 0]),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0, 1]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4, 0, 1]),
                            )],
          positive_node_pairs=(tensor([0, 1]),
                              tensor([0, 0])),
          node_pairs_with_labels=((tensor([0, 1, 0, 0, 1, 1]), tensor([0, 0, 2, 1, 2, 3])),
                                 tensor([1., 1., 0., 0., 0., 0.])),
          node_pairs=(tensor([5, 4]),
                     tensor([5, 5])),
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294],
                                [0.6172, 0.7865]])},
          negative_srcs=None,
          negative_node_pairs=(tensor([[0, 0],
                                      [1, 1]]),
                              tensor([[2, 1],
                                      [2, 3]])),
          negative_dsts=tensor([[0, 4],
                                [0, 1]]),
          labels=None,
          input_nodes=tensor([5, 4, 0, 1]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_node_pairs=(tensor([0, 1]),
                               tensor([0, 0])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=tensor([[2, 1],
                                          [2, 3]]),
          blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
                 Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
       )""",
    ]
    # Materialize first so that a missing or extra batch is reported as a
    # count mismatch rather than passing silently or raising IndexError.
    actual = [str(data) for data in dataloader]
    assert len(actual) == len(
        expected
    ), f"expected {len(expected)} minibatches, got {len(actual)}"
    for step, (want, got) in enumerate(zip(expected, actual)):
        assert want == got, f"minibatch {step} differs from snapshot:\n{got}"


def test_integration_node_classification():
    """Check the printed minibatches of a neighbor-sampling GraphBolt pipeline.

    Builds a 6-node, 10-edge graph from a fixed CSC layout and runs the
    pipeline ``ItemSampler -> sample_neighbor -> fetch_feature ->
    DataLoader`` (no negative sampling, no seed-edge exclusion). The string
    form of every produced minibatch is then compared against a recorded
    snapshot.

    The RNG is seeded first, so the expected strings are only valid for this
    exact seed and this exact order of pipeline stages.
    """
    torch.manual_seed(926)

    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])

    # Every edge of the graph is used as a seed node pair.
    matrix_a = dglsp.from_csc(indptr, indices)
    node_pairs = torch.t(torch.stack(matrix_a.coo()))
    node_feature_data = torch.tensor(
        [
            [0.9634, 0.2294],
            [0.6172, 0.7865],
            [0.2109, 0.1089],
            [0.8672, 0.2276],
            [0.5503, 0.8223],
            [0.5160, 0.2486],
        ]
    )
    edge_feature_data = torch.tensor(
        [
            [0.5123, 0.1709, 0.6150],
            [0.1476, 0.1902, 0.1314],
            [0.2582, 0.5203, 0.6228],
            [0.3708, 0.7631, 0.2683],
            [0.2126, 0.7878, 0.7225],
            [0.7885, 0.3414, 0.5485],
            [0.4088, 0.8200, 0.1851],
            [0.0056, 0.9469, 0.4432],
            [0.8972, 0.7511, 0.3617],
            [0.5773, 0.2199, 0.3366],
        ]
    )

    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    node_feature = gb.TorchBasedFeature(node_feature_data)
    edge_feature = gb.TorchBasedFeature(edge_feature_data)
    features = {
        ("node", None, "feat"): node_feature,
        ("edge", None, "feat"): edge_feature,
    }
    feature_store = gb.BasicFeatureStore(features)
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    fanouts = torch.LongTensor([1])
    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
    datapipe = datapipe.fetch_feature(
        feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
    )
    dataloader = gb.DataLoader(
        datapipe,
    )
    # Snapshot of str(minibatch) for each batch; whitespace is significant.
    expected = [
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
                                                                         indices=tensor([4, 1, 0, 1]),
                                                           ),
                                               original_row_node_ids=tensor([5, 3, 1, 2, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 3, 1, 2]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
                                                                         indices=tensor([0, 1, 0, 1]),
                                                           ),
                                               original_row_node_ids=tensor([5, 3, 1, 2]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 3, 1, 2]),
                            )],
          positive_node_pairs=(tensor([0, 1, 1, 1]),
                              tensor([2, 3, 3, 1])),
          node_pairs_with_labels=None,
          node_pairs=(tensor([5, 3, 3, 3]),
                     tensor([1, 2, 2, 3])),
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.8672, 0.2276],
                                [0.6172, 0.7865],
                                [0.2109, 0.1089],
                                [0.5503, 0.8223]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([5, 3, 1, 2, 4]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_node_pairs=(tensor([0, 1, 1, 1]),
                               tensor([2, 3, 3, 1])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
                 Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
       )""",
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
                                                                         indices=tensor([0, 2]),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
                                                                         indices=tensor([0, 2]),
                                                           ),
                                               original_row_node_ids=tensor([3, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([3, 4, 0]),
                            )],
          positive_node_pairs=(tensor([0, 1, 1, 2]),
                              tensor([0, 0, 1, 1])),
          node_pairs_with_labels=None,
          node_pairs=(tensor([3, 4, 4, 0]),
                     tensor([3, 3, 4, 4])),
          node_features={'feat': tensor([[0.8672, 0.2276],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([3, 4, 0]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_node_pairs=(tensor([0, 1, 1, 2]),
                               tensor([0, 0, 1, 1])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
                 Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
       )""",
        """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
                                                                         indices=tensor([0, 2]),
                                                           ),
                                               original_row_node_ids=tensor([5, 4, 0]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
                                                                         indices=tensor([1, 1]),
                                                           ),
                                               original_row_node_ids=tensor([5, 4]),
                                               original_edge_ids=None,
                                               original_column_node_ids=tensor([5, 4]),
                            )],
          positive_node_pairs=(tensor([0, 1]),
                              tensor([0, 0])),
          node_pairs_with_labels=None,
          node_pairs=(tensor([5, 4]),
                     tensor([5, 5])),
          node_features={'feat': tensor([[0.5160, 0.2486],
                                [0.5503, 0.8223],
                                [0.9634, 0.2294]])},
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=tensor([5, 4, 0]),
          indexes=None,
          edge_features=[{},
                        {}],
          compacted_node_pairs=(tensor([0, 1]),
                               tensor([0, 0])),
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
                 Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
       )""",
    ]
    # Materialize first so that a missing or extra batch is reported as a
    # count mismatch rather than passing silently or raising IndexError.
    actual = [str(data) for data in dataloader]
    assert len(actual) == len(
        expected
    ), f"expected {len(expected)} minibatches, got {len(actual)}"
    for step, (want, got) in enumerate(zip(expected, actual)):
        assert want == got, f"minibatch {step} differs from snapshot:\n{got}"