test_minibatch.py 37.7 KB
Newer Older
1
2
import dgl
import dgl.graphbolt as gb
peizhou001's avatar
peizhou001 committed
3
import pytest
4
5
6
import torch


peizhou001's avatar
peizhou001 committed
7
8
9
10
# Edge-type strings shared by the heterogeneous fixtures and tests below.
# Presumably formatted as "src_type:edge_name:dst_type" (they are parsed with
# gb.etype_str_to_tuple further down) -- confirm against GraphBolt docs.
relation = "A:r:B"
# Companion edge type between the same node types, named "rr".
reverse_relation = "B:rr:A"


11
12
13
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
    """Check ``str(MiniBatch)`` for an empty and a fully populated homo batch.

    The expected texts are compared character by character, so whitespace
    inside the triple-quoted literals is significant.

    Args:
        indptr_dtype: dtype used for the ``indptr`` tensors of the sampled
            CSC structures.
        indices_dtype: dtype used for the ``indices`` tensors of the sampled
            CSC structures.
    """
    csc_formats = [
        gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
            indices=torch.tensor([0, 1, 2, 2, 1, 2], dtype=indices_dtype),
        ),
        gb.CSCFormatBase(
            indptr=torch.tensor([0, 2, 3], dtype=indptr_dtype),
            indices=torch.tensor([1, 2, 0], dtype=indices_dtype),
        ),
    ]
    original_column_node_ids = [
        torch.tensor([10, 11, 12, 13]),
        torch.tensor([10, 11]),
    ]
    original_row_node_ids = [
        torch.tensor([10, 11, 12, 13]),
        torch.tensor([10, 11, 12]),
    ]
    original_edge_ids = [
        torch.tensor([19, 20, 21, 22, 25, 30]),
        torch.tensor([10, 15, 17]),
    ]
    node_features = {"x": torch.tensor([5, 0, 2, 1])}
    edge_features = [
        {"x": torch.tensor([9, 0, 1, 1, 7, 4])},
        {"x": torch.tensor([0, 2, 2])},
    ]
    # One sampled subgraph per layer; all four lists hold exactly two entries.
    subgraphs = [
        gb.SampledSubgraphImpl(
            sampled_csc=csc,
            original_column_node_ids=column_ids,
            original_row_node_ids=row_ids,
            original_edge_ids=edge_ids,
        )
        for csc, column_ids, row_ids, edge_ids in zip(
            csc_formats,
            original_column_node_ids,
            original_row_node_ids,
            original_edge_ids,
        )
    ]
    negative_srcs = torch.tensor([[8], [1], [6]])
    negative_dsts = torch.tensor([[2], [8], [8]])
    input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
    compacted_csc_formats = gb.CSCFormatBase(
        indptr=torch.tensor([0, 2, 3]), indices=torch.tensor([3, 4, 5])
    )
    compacted_negative_srcs = torch.tensor([[0], [1], [2]])
    compacted_negative_dsts = torch.tensor([[6], [0], [0]])
    labels = torch.tensor([0.0, 1.0, 2.0])

    # Test minibatch without data.
    minibatch = gb.MiniBatch()
    expect_result = """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=None,
          positive_node_pairs=None,
          node_pairs_with_labels=None,
          node_pairs=None,
          node_features=None,
          negative_srcs=None,
          negative_node_pairs=None,
          negative_dsts=None,
          labels=None,
          input_nodes=None,
          indexes=None,
          edge_features=None,
          compacted_seeds=None,
          compacted_node_pairs=None,
          compacted_negative_srcs=None,
          compacted_negative_dsts=None,
          blocks=None,
       )"""
    result = str(minibatch)
    # A real message (instead of ``print(...)``, which evaluates to None) so
    # both texts show up in the assertion error itself.
    assert (
        result == expect_result
    ), f"Empty MiniBatch repr mismatch.\nExpected:\n{expect_result}\nActual:\n{result}"

    # Test minibatch with all attributes.
    minibatch = gb.MiniBatch(
        node_pairs=csc_formats,
        sampled_subgraphs=subgraphs,
        labels=labels,
        node_features=node_features,
        edge_features=edge_features,
        negative_srcs=negative_srcs,
        negative_dsts=negative_dsts,
        compacted_node_pairs=compacted_csc_formats,
        input_nodes=input_nodes,
        compacted_negative_srcs=compacted_negative_srcs,
        compacted_negative_dsts=compacted_negative_dsts,
    )
    # NOTE(review): the expected text hard-codes ``dtype=torch.int32`` for the
    # sampled CSC tensors although int64 is parametrized as well, and torch
    # normally omits the dtype suffix when printing int64 tensors -- confirm
    # that the int64 parametrizations really produce this text.
    expect_result = """MiniBatch(seeds=None,
          seed_nodes=None,
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
                                                                         indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([10, 11, 12, 13]),
                                               original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
                                               original_column_node_ids=tensor([10, 11, 12, 13]),
                            ),
                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),
                                                                         indices=tensor([1, 2, 0], dtype=torch.int32),
                                                           ),
                                               original_row_node_ids=tensor([10, 11, 12]),
                                               original_edge_ids=tensor([10, 15, 17]),
                                               original_column_node_ids=tensor([10, 11]),
                            )],
          positive_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
                                            indices=tensor([3, 4, 5]),
                              ),
          node_pairs_with_labels=(CSCFormatBase(indptr=tensor([0, 2, 3]),
                                               indices=tensor([3, 4, 5]),
                                 ),
                                 tensor([0., 1., 2.])),
          node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
                                   indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
                     ),
                     CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),
                                   indices=tensor([1, 2, 0], dtype=torch.int32),
                     )],
          node_features={'x': tensor([5, 0, 2, 1])},
          negative_srcs=tensor([[8],
                                [1],
                                [6]]),
          negative_node_pairs=(tensor([[0],
                                      [1],
                                      [2]]),
                              tensor([[6],
                                      [0],
                                      [0]])),
          negative_dsts=tensor([[2],
                                [8],
                                [8]]),
          labels=tensor([0., 1., 2.]),
          input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
          indexes=None,
          edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},
                        {'x': tensor([0, 2, 2])}],
          compacted_seeds=None,
          compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
                                             indices=tensor([3, 4, 5]),
                               ),
          compacted_negative_srcs=tensor([[0],
                                          [1],
                                          [2]]),
          compacted_negative_dsts=tensor([[6],
                                          [0],
                                          [0]]),
          blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6),
                 Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)],
       )"""
    result = str(minibatch)
    assert (
        result == expect_result
    ), f"Full MiniBatch repr mismatch.\nExpected:\n{expect_result}\nActual:\n{result}"
peizhou001's avatar
peizhou001 committed
164
165


166
167
168
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
    """Check ``str(MiniBatch)`` for a fully populated heterogeneous batch.

    The expected text is compared character by character, so whitespace
    inside the triple-quoted literal is significant.

    Args:
        indptr_dtype: dtype used for the ``indptr`` tensors of the sampled
            CSC structures.
        indices_dtype: dtype used for the ``indices`` tensors of the sampled
            CSC structures.
    """
    csc_formats = [
        {
            relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 1, 2, 3], dtype=indptr_dtype),
                indices=torch.tensor([0, 1, 1], dtype=indices_dtype),
            ),
            reverse_relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 0, 0, 1, 2], dtype=indptr_dtype),
                indices=torch.tensor([1, 0], dtype=indices_dtype),
            ),
        },
        {
            relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 1, 2], dtype=indptr_dtype),
                indices=torch.tensor([1, 0], dtype=indices_dtype),
            )
        },
    ]
    # Key order inside these dicts is reflected in the expected text below.
    original_column_node_ids = [
        {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
        {"B": torch.tensor([10, 11])},
    ]
    original_row_node_ids = [
        {
            "A": torch.tensor([5, 7, 9, 11]),
            "B": torch.tensor([10, 11, 12]),
        },
        {
            "A": torch.tensor([5, 7]),
            "B": torch.tensor([10, 11]),
        },
    ]
    original_edge_ids = [
        {
            relation: torch.tensor([19, 20, 21]),
            reverse_relation: torch.tensor([23, 26]),
        },
        {relation: torch.tensor([10, 12])},
    ]
    node_features = {
        ("A", "x"): torch.tensor([6, 4, 0, 1]),
    }
    edge_features = [
        {(relation, "x"): torch.tensor([4, 2, 4])},
        {(relation, "x"): torch.tensor([0, 6])},
    ]
    # One sampled subgraph per layer; all four lists hold exactly two entries.
    subgraphs = [
        gb.SampledSubgraphImpl(
            sampled_csc=csc,
            original_column_node_ids=column_ids,
            original_row_node_ids=row_ids,
            original_edge_ids=edge_ids,
        )
        for csc, column_ids, row_ids, edge_ids in zip(
            csc_formats,
            original_column_node_ids,
            original_row_node_ids,
            original_edge_ids,
        )
    ]
    negative_srcs = {"B": torch.tensor([[8], [1], [6]])}
    negative_dsts = {"B": torch.tensor([[2], [8], [8]])}
    compacted_csc_formats = {
        relation: gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 2, 3]), indices=torch.tensor([3, 4, 5])
        ),
        reverse_relation: gb.CSCFormatBase(
            indptr=torch.tensor([0, 0, 0, 1, 2]), indices=torch.tensor([0, 1])
        ),
    }
    compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
    compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}

    # Test minibatch with all attributes.
    minibatch = gb.MiniBatch(
        seed_nodes={"B": torch.tensor([10, 15])},
        node_pairs=csc_formats,
        sampled_subgraphs=subgraphs,
        node_features=node_features,
        edge_features=edge_features,
        labels={"B": torch.tensor([2, 5])},
        negative_srcs=negative_srcs,
        negative_dsts=negative_dsts,
        compacted_node_pairs=compacted_csc_formats,
        input_nodes={
            "A": torch.tensor([5, 7, 9, 11]),
            "B": torch.tensor([10, 11, 12]),
        },
        compacted_negative_srcs=compacted_negative_srcs,
        compacted_negative_dsts=compacted_negative_dsts,
    )
    # NOTE(review): the expected text hard-codes ``dtype=torch.int32`` for the
    # sampled CSC tensors although int64 is parametrized as well, and torch
    # normally omits the dtype suffix when printing int64 tensors -- confirm
    # that the int64 parametrizations really produce this text.
    expect_result = """MiniBatch(seeds=None,
          seed_nodes={'B': tensor([10, 15])},
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
                                                                         indices=tensor([0, 1, 1], dtype=torch.int32),
                                                           ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([1, 0], dtype=torch.int32),
                                                           )},
                                               original_row_node_ids={'A': tensor([ 5,  7,  9, 11]), 'B': tensor([10, 11, 12])},
                                               original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},
                                               original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5,  7,  9, 11])},
                            ),
                            SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
                                                                         indices=tensor([1, 0], dtype=torch.int32),
                                                           )},
                                               original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
                                               original_edge_ids={'A:r:B': tensor([10, 12])},
                                               original_column_node_ids={'B': tensor([10, 11])},
                            )],
          positive_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
                                            indices=tensor([3, 4, 5]),
                              ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
                                            indices=tensor([0, 1]),
                              )},
          node_pairs_with_labels=({'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
                                               indices=tensor([3, 4, 5]),
                                 ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
                                               indices=tensor([0, 1]),
                                 )},
                                 {'B': tensor([2, 5])}),
          node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
                                   indices=tensor([0, 1, 1], dtype=torch.int32),
                     ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
                                   indices=tensor([1, 0], dtype=torch.int32),
                     )},
                     {'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
                                   indices=tensor([1, 0], dtype=torch.int32),
                     )}],
          node_features={('A', 'x'): tensor([6, 4, 0, 1])},
          negative_srcs={'B': tensor([[8],
                                [1],
                                [6]])},
          negative_node_pairs={'A:r:B': (tensor([[0],
                                      [1],
                                      [2]]), tensor([[6],
                                      [0],
                                      [0]]))},
          negative_dsts={'B': tensor([[2],
                                [8],
                                [8]])},
          labels={'B': tensor([2, 5])},
          input_nodes={'A': tensor([ 5,  7,  9, 11]), 'B': tensor([10, 11, 12])},
          indexes=None,
          edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},
                        {('A:r:B', 'x'): tensor([0, 6])}],
          compacted_seeds=None,
          compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
                                             indices=tensor([3, 4, 5]),
                               ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
                                             indices=tensor([0, 1]),
                               )},
          compacted_negative_srcs={'A:r:B': tensor([[0],
                                          [1],
                                          [2]])},
          compacted_negative_dsts={'A:r:B': tensor([[6],
                                          [0],
                                          [0]])},
          blocks=[Block(num_src_nodes={'A': 4, 'B': 3},
                       num_dst_nodes={'A': 4, 'B': 3},
                       num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
                       metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]),
                 Block(num_src_nodes={'A': 2, 'B': 2},
                       num_dst_nodes={'B': 2},
                       num_edges={('A', 'r', 'B'): 2},
                       metagraph=[('A', 'B', 'r')])],
       )"""
    result = str(minibatch)
    # A real message (instead of ``print(...)``, which evaluates to None) so
    # both texts show up in the assertion error itself.
    assert (
        result == expect_result
    ), f"Hetero MiniBatch repr mismatch.\nExpected:\n{expect_result}\nActual:\n{result}"
334
335


336
337
338
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
    """Check the text form of ``MiniBatch.blocks`` for a homogeneous batch.

    Args:
        indptr_dtype: dtype used for the ``indptr`` tensors of the sampled
            CSC structures.
        indices_dtype: dtype used for the ``indices`` tensors of the sampled
            CSC structures.
    """
    # ``node_pairs`` is handed to the MiniBatch as tuples of two tensors; the
    # dtype-parametrized CSC structures go into the sampled subgraphs.
    node_pairs = [
        (
            torch.tensor([0, 1, 2, 2, 2, 1]),
            torch.tensor([0, 1, 1, 2, 3, 2]),
        ),
        (
            torch.tensor([0, 1, 2]),
            torch.tensor([1, 0, 0]),
        ),
    ]
    csc_formats = [
        gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
            indices=torch.tensor([0, 1, 2, 2, 1, 2], dtype=indices_dtype),
        ),
        gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 3], dtype=indptr_dtype),
            indices=torch.tensor([0, 1, 2], dtype=indices_dtype),
        ),
    ]
    original_column_node_ids = [
        torch.tensor([10, 11, 12, 13]),
        torch.tensor([10, 11]),
    ]
    original_row_node_ids = [
        torch.tensor([10, 11, 12, 13]),
        torch.tensor([10, 11, 12]),
    ]
    original_edge_ids = [
        torch.tensor([19, 20, 21, 22, 25, 30]),
        torch.tensor([10, 15, 17]),
    ]
    node_features = {"x": torch.tensor([7, 6, 2, 2])}
    edge_features = [
        {"x": torch.tensor([[8], [1], [6]])},
        {"x": torch.tensor([[2], [8], [8]])},
    ]
    # One sampled subgraph per layer; all four lists hold exactly two entries.
    subgraphs = [
        gb.SampledSubgraphImpl(
            sampled_csc=csc,
            original_column_node_ids=column_ids,
            original_row_node_ids=row_ids,
            original_edge_ids=edge_ids,
        )
        for csc, column_ids, row_ids, edge_ids in zip(
            csc_formats,
            original_column_node_ids,
            original_row_node_ids,
            original_edge_ids,
        )
    ]
    negative_srcs = torch.tensor([[8], [1], [6]])
    negative_dsts = torch.tensor([[2], [8], [8]])
    input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
    compacted_node_pairs = (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5]))
    compacted_negative_srcs = torch.tensor([[0], [1], [2]])
    compacted_negative_dsts = torch.tensor([[6], [0], [0]])
    labels = torch.tensor([0.0, 1.0, 2.0])

    # Test minibatch with all attributes.
    minibatch = gb.MiniBatch(
        node_pairs=node_pairs,
        sampled_subgraphs=subgraphs,
        labels=labels,
        node_features=node_features,
        edge_features=edge_features,
        negative_srcs=negative_srcs,
        negative_dsts=negative_dsts,
        compacted_node_pairs=compacted_node_pairs,
        input_nodes=input_nodes,
        compacted_negative_srcs=compacted_negative_srcs,
        compacted_negative_dsts=compacted_negative_dsts,
    )
    dgl_blocks = minibatch.blocks
    expect_result = (
        "[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6), "
        "Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)]"
    )
    result = str(dgl_blocks)
    # A real message (instead of ``print(...)``, which evaluates to None) so
    # the mismatch is visible in the assertion error itself.
    assert (
        result == expect_result
    ), f"Unexpected blocks text.\nExpected:\n{expect_result}\nActual:\n{result}"


415
def test_get_dgl_blocks_hetero():
    """Check the text form of ``MiniBatch.blocks`` for a heterogeneous batch."""
    # ``node_pairs`` is handed to the MiniBatch as per-relation tuples of two
    # tensors; the CSC structures go into the sampled subgraphs.
    node_pairs = [
        {
            relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
            reverse_relation: (torch.tensor([1, 0]), torch.tensor([2, 3])),
        },
        {relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
    ]
    csc_formats = [
        {
            relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 1, 2, 3]),
                indices=torch.tensor([0, 1, 1]),
            ),
            reverse_relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 0, 0, 1, 2]),
                indices=torch.tensor([1, 0]),
            ),
        },
        {
            relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])
            )
        },
    ]
    original_column_node_ids = [
        {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
        {"B": torch.tensor([10, 11])},
    ]
    original_row_node_ids = [
        {
            "A": torch.tensor([5, 7, 9, 11]),
            "B": torch.tensor([10, 11, 12]),
        },
        {
            "A": torch.tensor([5, 7]),
            "B": torch.tensor([10, 11]),
        },
    ]
    original_edge_ids = [
        {
            relation: torch.tensor([19, 20, 21]),
            reverse_relation: torch.tensor([23, 26]),
        },
        {relation: torch.tensor([10, 12])},
    ]
    node_features = {
        ("A", "x"): torch.tensor([6, 4, 0, 1]),
    }
    edge_features = [
        {(relation, "x"): torch.tensor([4, 2, 4])},
        {(relation, "x"): torch.tensor([0, 6])},
    ]
    # One sampled subgraph per layer; all four lists hold exactly two entries.
    subgraphs = [
        gb.SampledSubgraphImpl(
            sampled_csc=csc,
            original_column_node_ids=column_ids,
            original_row_node_ids=row_ids,
            original_edge_ids=edge_ids,
        )
        for csc, column_ids, row_ids, edge_ids in zip(
            csc_formats,
            original_column_node_ids,
            original_row_node_ids,
            original_edge_ids,
        )
    ]
    negative_srcs = {"B": torch.tensor([[8], [1], [6]])}
    negative_dsts = {"B": torch.tensor([[2], [8], [8]])}
    compacted_node_pairs = {
        relation: (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])),
        reverse_relation: (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])),
    }
    compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
    compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}

    # Test minibatch with all attributes.
    minibatch = gb.MiniBatch(
        seed_nodes={"B": torch.tensor([10, 15])},
        node_pairs=node_pairs,
        sampled_subgraphs=subgraphs,
        node_features=node_features,
        edge_features=edge_features,
        labels={"B": torch.tensor([2, 5])},
        negative_srcs=negative_srcs,
        negative_dsts=negative_dsts,
        compacted_node_pairs=compacted_node_pairs,
        input_nodes={
            "A": torch.tensor([5, 7, 9, 11]),
            "B": torch.tensor([10, 11, 12]),
        },
        compacted_negative_srcs=compacted_negative_srcs,
        compacted_negative_dsts=compacted_negative_dsts,
    )
    dgl_blocks = minibatch.blocks
    # Whitespace inside this literal is significant: it is compared verbatim.
    expect_result = """[Block(num_src_nodes={'A': 4, 'B': 3},
      num_dst_nodes={'A': 4, 'B': 3},
      num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
      metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]), Block(num_src_nodes={'A': 2, 'B': 2},
      num_dst_nodes={'B': 2},
      num_edges={('A', 'r', 'B'): 2},
      metagraph=[('A', 'B', 'r')])]"""
    result = str(dgl_blocks)
    # A real message (instead of ``print(...)``, which evaluates to None) so
    # the mismatch is visible in the assertion error itself.
    assert (
        result == expect_result
    ), f"Unexpected blocks text.\nExpected:\n{expect_result}\nActual:\n{result}"


518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
@pytest.mark.parametrize(
    "mode", ["neg_graph", "neg_src", "neg_dst", "edge_classification"]
)
def test_minibatch_node_pairs_with_labels(mode):
    """Check ``MiniBatch.node_pairs_with_labels`` for each supported mode.

    Args:
        mode: which optional attributes are set on the batch before reading
            the property -- negative sources, negative destinations, both
            ("neg_graph"), or plain labels ("edge_classification").
    """
    # Arrange
    batch = create_homo_minibatch()
    batch.compacted_node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 0]))
    if mode in ("neg_graph", "neg_src"):
        batch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
    if mode in ("neg_graph", "neg_dst"):
        batch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
    if mode == "edge_classification":
        batch.labels = torch.tensor([0, 1]).long()

    # Act
    actual_pairs, actual_labels = batch.node_pairs_with_labels

    # Assert
    if mode == "edge_classification":
        # Only the positive pairs and the user-provided labels come back.
        wanted_src = torch.tensor([0, 1])
        wanted_dst = torch.tensor([1, 0])
        wanted_labels = torch.tensor([0, 1]).long()
    else:
        # Two positive pairs followed by four negative pairs.
        wanted_src = torch.tensor([0, 1, 0, 0, 1, 1])
        if mode == "neg_src":
            wanted_dst = torch.tensor([1, 0, 1, 1, 0, 0])
        else:
            wanted_dst = torch.tensor([1, 0, 1, 0, 0, 1])
        wanted_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
    assert torch.equal(actual_pairs[0], wanted_src)
    assert torch.equal(actual_pairs[1], wanted_dst)
    assert torch.equal(actual_labels, wanted_labels)


561
def create_homo_minibatch():
    """Build a two-layer homogeneous ``gb.MiniBatch`` fixture.

    Node and edge features are random integers in ``[0, 10)``; the graph
    structure and original IDs are fixed.

    Returns:
        A ``gb.MiniBatch`` with ``sampled_subgraphs``, ``node_features``,
        ``edge_features`` and ``input_nodes`` set.
    """
    layer_structures = (
        gb.CSCFormatBase(
            indptr=torch.tensor([0, 1, 3, 5, 6]),
            indices=torch.tensor([0, 1, 2, 2, 1, 2]),
        ),
        gb.CSCFormatBase(
            indptr=torch.tensor([0, 2, 3]),
            indices=torch.tensor([1, 2, 0]),
        ),
    )
    layer_column_ids = (
        torch.tensor([10, 11, 12, 13]),
        torch.tensor([10, 11]),
    )
    layer_row_ids = (
        torch.tensor([10, 11, 12, 13]),
        torch.tensor([10, 11, 12]),
    )
    layer_edge_ids = (
        torch.tensor([19, 20, 21, 22, 25, 30]),
        torch.tensor([10, 15, 17]),
    )

    # Random draws happen in this order: node features, then the edge
    # features of layer 0 and layer 1.
    random_node_features = {"x": torch.randint(0, 10, (4,))}
    random_edge_features = [
        {"x": torch.randint(0, 10, (6,))},
        {"x": torch.randint(0, 10, (3,))},
    ]

    layers = [
        gb.SampledSubgraphImpl(
            sampled_csc=layer_structures[layer],
            original_column_node_ids=layer_column_ids[layer],
            original_row_node_ids=layer_row_ids[layer],
            original_edge_ids=layer_edge_ids[layer],
        )
        for layer in range(2)
    ]
    return gb.MiniBatch(
        sampled_subgraphs=layers,
        node_features=random_node_features,
        edge_features=random_edge_features,
        input_nodes=torch.tensor([10, 11, 12, 13]),
    )


607
def create_hetero_minibatch():
    """Build a two-layer heterogeneous ``gb.MiniBatch`` fixture.

    The first layer holds both ``relation`` and ``reverse_relation`` edges,
    the second layer only ``relation`` edges. Node and edge features are
    random integers in ``[0, 10)``.

    Returns:
        A ``gb.MiniBatch`` with ``sampled_subgraphs``, ``node_features``,
        ``edge_features`` and ``input_nodes`` set.
    """
    layer_structures = (
        {
            relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 1, 2, 3]),
                indices=torch.tensor([0, 1, 1]),
            ),
            reverse_relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 0, 0, 1, 2]),
                indices=torch.tensor([1, 0]),
            ),
        },
        {
            relation: gb.CSCFormatBase(
                indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])
            )
        },
    )
    layer_column_ids = (
        {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
        {"B": torch.tensor([10, 11])},
    )
    layer_row_ids = (
        {"A": torch.tensor([5, 7, 9, 11]), "B": torch.tensor([10, 11, 12])},
        {"A": torch.tensor([5, 7]), "B": torch.tensor([10, 11])},
    )
    layer_edge_ids = (
        {
            relation: torch.tensor([19, 20, 21]),
            reverse_relation: torch.tensor([23, 26]),
        },
        {relation: torch.tensor([10, 12])},
    )

    # Random draws happen in this order: node features, then the edge
    # features of layer 0 and layer 1.
    random_node_features = {("A", "x"): torch.randint(0, 10, (4,))}
    random_edge_features = [
        {(relation, "x"): torch.randint(0, 10, (3,))},
        {(relation, "x"): torch.randint(0, 10, (2,))},
    ]

    layers = [
        gb.SampledSubgraphImpl(
            sampled_csc=layer_structures[layer],
            original_column_node_ids=layer_column_ids[layer],
            original_row_node_ids=layer_row_ids[layer],
            original_edge_ids=layer_edge_ids[layer],
        )
        for layer in range(2)
    ]
    return gb.MiniBatch(
        sampled_subgraphs=layers,
        node_features=random_node_features,
        edge_features=random_edge_features,
        input_nodes={
            "A": torch.tensor([5, 7, 9, 11]),
            "B": torch.tensor([10, 11, 12]),
        },
    )


674
def check_dgl_blocks_hetero(minibatch, blocks):
    """Assert that heterogeneous blocks mirror the batch's sampled subgraphs.

    For every block the ``relation`` edges and their ``dgl.EID`` data are
    compared with the matching sampled subgraph. The ``reverse_relation``
    edges and the source ``dgl.NID`` data are checked on the first block only.

    Args:
        minibatch: object exposing ``sampled_subgraphs``.
        blocks: blocks to verify, in the same order as the subgraphs.
    """

    def expected_dst_nodes(csc):
        # Expand the CSC column pointer into one column index per edge.
        column_count = len(csc.indptr) - 1
        return torch.arange(0, column_count).repeat_interleave(
            csc.indptr.diff()
        )

    forward_etype = gb.etype_str_to_tuple(relation)
    layer_structures = [sg.sampled_csc for sg in minibatch.sampled_subgraphs]
    layer_edge_ids = [
        sg.original_edge_ids for sg in minibatch.sampled_subgraphs
    ]
    layer_row_ids = [
        sg.original_row_node_ids for sg in minibatch.sampled_subgraphs
    ]

    for layer, block in enumerate(blocks):
        src, dst = block.edges(etype=forward_etype)
        forward_csc = layer_structures[layer][relation]
        dst_nodes = expected_dst_nodes(forward_csc)
        assert torch.equal(src, forward_csc.indices)
        assert torch.equal(dst, dst_nodes)
        assert torch.equal(
            block.edges[forward_etype].data[dgl.EID],
            layer_edge_ids[layer][relation],
        )

    # Only the first layer carries reverse-relation edges.
    first_block = blocks[0]
    src, dst = first_block.edges(
        etype=gb.etype_str_to_tuple(reverse_relation)
    )
    backward_csc = layer_structures[0][reverse_relation]
    dst_nodes = expected_dst_nodes(backward_csc)
    assert torch.equal(src, backward_csc.indices)
    assert torch.equal(dst, dst_nodes)
    for node_type in ("A", "B"):
        assert torch.equal(
            first_block.srcdata[dgl.NID][node_type],
            layer_row_ids[0][node_type],
        )


711
def check_dgl_blocks_homo(minibatch, blocks):
    """Assert that homogeneous blocks mirror the batch's sampled subgraphs.

    For every block the edge endpoints and the ``dgl.EID`` edge data are
    compared with the matching sampled subgraph. The source ``dgl.NID`` data
    is checked on the first block only.

    Args:
        minibatch: object exposing ``sampled_subgraphs``.
        blocks: blocks to verify, in the same order as the subgraphs.

    Raises:
        AssertionError: if any block disagrees with its subgraph; the message
            contains the offending tensor.
    """
    sampled_csc = [
        subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
    ]
    original_edge_ids = [
        subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
    ]
    original_row_node_ids = [
        subgraph.original_row_node_ids
        for subgraph in minibatch.sampled_subgraphs
    ]
    for i, block in enumerate(blocks):
        indptr = sampled_csc[i].indptr
        # Expand the CSC column pointer into one column index per edge.
        dst_nodes = torch.arange(0, len(indptr) - 1).repeat_interleave(
            indptr.diff()
        )
        # Query the edges once instead of once per assertion.
        src, dst = block.edges()
        assert torch.equal(
            src, sampled_csc[i].indices
        ), f"block {i}: unexpected edge sources {src}"
        assert torch.equal(
            dst, dst_nodes
        ), f"block {i}: unexpected edge destinations {dst}"
        edge_ids = block.edata[dgl.EID]
        assert torch.equal(
            edge_ids, original_edge_ids[i]
        ), f"block {i}: unexpected edge IDs {edge_ids}"
    src_node_ids = blocks[0].srcdata[dgl.NID]
    assert torch.equal(
        src_node_ids, original_row_node_ids[0]
    ), f"block 0: unexpected source node IDs {src_node_ids}"


738
def test_dgl_node_classification_without_feature():
    """Build blocks from a homogeneous minibatch with no features or labels."""
    # Arrange: clear features and labels, then set the seed nodes.
    batch = create_homo_minibatch()
    batch.node_features = None
    batch.labels = None
    batch.seed_nodes = torch.tensor([10, 15])

    # Act
    blocks = batch.blocks

    # Assert: two blocks are produced and the cleared fields remain None.
    assert len(blocks) == 2
    assert batch.node_features is None
    assert batch.labels is None
    check_dgl_blocks_homo(batch, blocks)
752
753


754
def test_dgl_node_classification_homo():
    """Build blocks from a homogeneous minibatch with seed nodes and labels."""
    # Arrange
    batch = create_homo_minibatch()
    batch.seed_nodes = torch.tensor([10, 15])
    batch.labels = torch.tensor([2, 5])

    # Act
    blocks = batch.blocks

    # Assert: two blocks are produced and they mirror the sampled subgraphs.
    assert len(blocks) == 2
    check_dgl_blocks_homo(batch, blocks)
765
766


767
768
def test_dgl_node_classification_hetero():
    """Build blocks from a heterogeneous minibatch with per-type seeds/labels."""
    # Arrange: labels and seed nodes are keyed by node type "B".
    batch = create_hetero_minibatch()
    batch.labels = {"B": torch.tensor([2, 5])}
    batch.seed_nodes = {"B": torch.tensor([10, 15])}

    # Act
    blocks = batch.blocks

    # Assert: two blocks are produced and they mirror the sampled subgraphs.
    assert len(blocks) == 2
    check_dgl_blocks_hetero(batch, blocks)
777
778
779


@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_homo(mode):
    """Check blocks, negative pairs and labelled pairs on a homo minibatch.

    ``mode`` selects which compacted negatives are set on the minibatch:
    sources for "neg_graph"/"neg_src", destinations for "neg_graph"/"neg_dst".
    """
    with_neg_srcs = mode in ("neg_graph", "neg_src")
    with_neg_dsts = mode in ("neg_graph", "neg_dst")

    # Arrange
    batch = create_homo_minibatch()
    batch.compacted_node_pairs = (
        torch.tensor([0, 1]),
        torch.tensor([1, 0]),
    )
    if with_neg_srcs:
        batch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
    if with_neg_dsts:
        batch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])

    # Act
    blocks = batch.blocks

    # Assert
    assert len(blocks) == 2
    check_dgl_blocks_homo(batch, blocks)
    if with_neg_srcs:
        assert torch.equal(
            batch.negative_node_pairs[0], batch.compacted_negative_srcs
        )
    if with_neg_dsts:
        assert torch.equal(
            batch.negative_node_pairs[1], batch.compacted_negative_dsts
        )

    node_pairs, labels = batch.node_pairs_with_labels
    expected_srcs = torch.tensor([0, 1, 0, 0, 1, 1])
    # Only the "neg_src" mode leaves the negative destinations unset, which
    # is the one case with a different expected destination sequence.
    expected_dsts = torch.tensor(
        [1, 0, 1, 1, 0, 0] if mode == "neg_src" else [1, 0, 1, 0, 0, 1]
    )
    expected_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
    assert torch.equal(node_pairs[0], expected_srcs)
    assert torch.equal(node_pairs[1], expected_dsts)
    assert torch.equal(labels, expected_labels)
825
826
827


@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_hetero(mode):
    """Check blocks and negative node pairs on a heterogeneous minibatch.

    ``mode`` selects which compacted negatives are set on the minibatch:
    sources for "neg_graph"/"neg_src", destinations for "neg_graph"/"neg_dst".
    All pair data is keyed by edge type (``relation``/``reverse_relation``).
    """
    with_neg_srcs = mode in ("neg_graph", "neg_src")
    with_neg_dsts = mode in ("neg_graph", "neg_dst")

    # Arrange
    minibatch = create_hetero_minibatch()
    minibatch.compacted_node_pairs = {
        relation: (
            torch.tensor([1, 1]),
            torch.tensor([1, 0]),
        ),
        reverse_relation: (
            torch.tensor([0, 1]),
            torch.tensor([1, 0]),
        ),
    }
    if with_neg_srcs:
        minibatch.compacted_negative_srcs = {
            relation: torch.tensor([[2, 0], [1, 2]]),
            reverse_relation: torch.tensor([[1, 2], [0, 2]]),
        }
    if with_neg_dsts:
        minibatch.compacted_negative_dsts = {
            relation: torch.tensor([[1, 3], [2, 1]]),
            reverse_relation: torch.tensor([[2, 1], [3, 1]]),
        }

    # Act
    dgl_blocks = minibatch.blocks

    # Assert
    assert len(dgl_blocks) == 2
    check_dgl_blocks_hetero(minibatch, dgl_blocks)
    if with_neg_srcs:
        for etype, src in minibatch.compacted_negative_srcs.items():
            assert torch.equal(minibatch.negative_node_pairs[etype][0], src)
    if with_neg_dsts:
        # Compare against the loop value directly, mirroring the src check.
        for etype, dst in minibatch.compacted_negative_dsts.items():
            assert torch.equal(minibatch.negative_node_pairs[etype][1], dst)
869
870


871
def test_to_pyg_data_original():
    """Check ``MiniBatch.to_pyg_data`` when seeds are given as ``seed_nodes``.

    Covers the populated case plus minibatches that lack sampled subgraphs,
    node features or labels, and the rejection of multiple node features.
    """
    test_minibatch = create_homo_minibatch()
    test_minibatch.seed_nodes = torch.tensor([0, 1])
    test_minibatch.labels = torch.tensor([7, 8])

    expected_edge_index = torch.tensor(
        [[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]
    )
    expected_node_features = next(iter(test_minibatch.node_features.values()))
    expected_labels = torch.tensor([7, 8])
    expected_batch_size = 2
    expected_n_id = torch.tensor([10, 11, 12, 13])

    pyg_data = test_minibatch.to_pyg_data()
    pyg_data.validate()
    assert torch.equal(pyg_data.edge_index, expected_edge_index)
    assert torch.equal(pyg_data.x, expected_node_features)
    assert torch.equal(pyg_data.y, expected_labels)
    assert pyg_data.batch_size == expected_batch_size
    assert torch.equal(pyg_data.n_id, expected_n_id)

    subgraph = test_minibatch.sampled_subgraphs[0]
    # Test with sampled_subgraphs as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=None,
        node_features={"feat": expected_node_features},
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.edge_index is None, "Edge index should be none."

    # Test with node_features as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features=None,
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.x is None, "Node features should be None."

    # Test with labels as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features={"feat": expected_node_features},
        labels=None,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.y is None, "Labels should be None."

    # Test with multiple features: the conversion must raise. Using
    # `pytest.raises` makes the test fail when no error is raised at all,
    # which a hand-written try/except could let pass silently.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features={
            "feat": expected_node_features,
            "extra_feat": torch.tensor([[3], [4]]),
        },
        labels=expected_labels,
    )
    with pytest.raises(AssertionError) as exc_info:
        test_minibatch.to_pyg_data()
    assert (
        str(exc_info.value)
        == "`to_pyg_data` only supports single feature homogeneous graph."
    )


def test_to_pyg_data():
    """Check ``MiniBatch.to_pyg_data`` when seeds are given as ``seeds``.

    Covers the populated case, several layouts of ``seeds``, minibatches that
    lack sampled subgraphs, node features or labels, and the rejection of
    multiple node features.
    """
    test_minibatch = create_homo_minibatch()
    test_minibatch.seeds = torch.tensor([0, 1])
    test_minibatch.labels = torch.tensor([7, 8])

    expected_edge_index = torch.tensor(
        [[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]
    )
    expected_node_features = next(iter(test_minibatch.node_features.values()))
    expected_labels = torch.tensor([7, 8])
    expected_batch_size = 2
    expected_n_id = torch.tensor([10, 11, 12, 13])

    pyg_data = test_minibatch.to_pyg_data()
    pyg_data.validate()
    assert torch.equal(pyg_data.edge_index, expected_edge_index)
    assert torch.equal(pyg_data.x, expected_node_features)
    assert torch.equal(pyg_data.y, expected_labels)
    assert pyg_data.batch_size == expected_batch_size
    assert torch.equal(pyg_data.n_id, expected_n_id)

    # Rebuild the PyG data after each change to `seeds`; without the rebuild
    # the assertion would only re-check the object created above.
    # NOTE(review): assumes `to_pyg_data` accepts 2-D and dict-valued seeds
    # and reports a batch size of 2 for each of them -- confirm.
    for seeds in (
        torch.tensor([[0, 1], [2, 3]]),
        {"A": torch.tensor([0, 1])},
        {"A": torch.tensor([[0, 1], [2, 3]])},
    ):
        test_minibatch.seeds = seeds
        pyg_data = test_minibatch.to_pyg_data()
        assert pyg_data.batch_size == expected_batch_size

    subgraph = test_minibatch.sampled_subgraphs[0]
    # Test with sampled_subgraphs as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=None,
        node_features={"feat": expected_node_features},
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.edge_index is None, "Edge index should be none."

    # Test with node_features as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features=None,
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.x is None, "Node features should be None."

    # Test with labels as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features={"feat": expected_node_features},
        labels=None,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.y is None, "Labels should be None."

    # Test with multiple features: the conversion must raise. Using
    # `pytest.raises` makes the test fail when no error is raised at all,
    # which a hand-written try/except could let pass silently.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features={
            "feat": expected_node_features,
            "extra_feat": torch.tensor([[3], [4]]),
        },
        labels=expected_labels,
    )
    with pytest.raises(AssertionError) as exc_info:
        test_minibatch.to_pyg_data()
    assert (
        str(exc_info.value)
        == "`to_pyg_data` only supports single feature homogeneous graph."
    )