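"""Unit tests for dgl.graphbolt subgraph samplers."""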
import unittest
from functools import partial

import backend as F

import dgl
import dgl.graphbolt as gb
import pytest
import torch
from torchdata.datapipes.iter import Mapper

from . import gb_test_utils


def test_SubgraphSampler_invoke():
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())

    # Invoke via class constructor.
    datapipe = gb.SubgraphSampler(item_sampler)
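    # The base SubgraphSampler does not implement the actual sampling, so
    # iterating over it is expected to raise NotImplementedError.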
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))

    # Invoke via functional form.
    datapipe = item_sampler.sample_subgraph()
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))


@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor):
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
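    # 10 seed nodes with batch_size=2 yield 5 mini-batches.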
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke via class constructor.
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts)
    assert len(list(datapipe)) == 5

    # Invoke via functional form.
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_fanouts(labor):
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2

    # `fanouts` is a list of tensors.
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5

    # `fanouts` is a list of integers.
    fanouts = [2 for _ in range(num_layer)]
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
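    # 10 seed nodes with batch_size=2 yield 5 mini-batches.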
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts)
    assert len(list(sampler_dp)) == 5


def to_link_batch(data):
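    """Wrap a batch of node pairs into a gb.MiniBatch."""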
    block = gb.MiniBatch(node_pairs=data)
    return block


@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor):
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
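    # Exclude edges that appear as seeds in the mini-batch from the
    # sampled subgraphs.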
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(labor):
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
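    # Attach one uniformly sampled negative pair per positive node pair.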
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    assert len(list(datapipe)) == 5


def get_hetero_graph():
    # COO graph:
    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
    # num_nodes = 5, num_n1 = 2, num_n2 = 3
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
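    # node_type_offset = [0, 2, 5]: nodes 0-1 are "n1", nodes 2-4 are "n2".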
    node_type_offset = torch.LongTensor([0, 2, 5])
    return gb.fused_csc_sampling_graph(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=ntypes,
        edge_type_to_id=etypes,
    )


@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Heterogenous sampling not yet supported on GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node_Hetero(labor):
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {"n2": gb.ItemSet(torch.arange(3), names="seed_nodes")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
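    # 3 seed nodes with batch_size=2 yield 2 mini-batches.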
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts)
    assert len(list(sampler_dp)) == 2
    for minibatch in sampler_dp:
        assert len(minibatch.sampled_subgraphs) == num_layer


@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Heterogenous sampling not yet supported on GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor):
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
                names="node_pairs",
            ),
            "n2:e2:n1": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
                names="node_pairs",
            ),
        }
    )

    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    assert len(list(datapipe)) == 5


@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Heterogenous sampling not yet supported on GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
                names="node_pairs",
            ),
            "n2:e2:n1": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
                names="node_pairs",
            ),
        }
    )

    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
    assert len(list(datapipe)) == 5


@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Random_Hetero_Graph(labor):
    num_nodes = 5
    num_edges = 9
    num_ntypes = 3
    num_etypes = 3
    (
        csc_indptr,
        indices,
        node_type_offset,
        type_per_edge,
        node_type_to_id,
        edge_type_to_id,
    ) = gb_test_utils.random_hetero_graph(
        num_nodes, num_edges, num_ntypes, num_etypes
    )
    edge_attributes = {
        "A1": torch.randn(num_edges),
        "A2": torch.randn(num_edges),
    }
    graph = gb.fused_csc_sampling_graph(
        csc_indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=node_type_to_id,
        edge_type_to_id=edge_type_to_id,
        edge_attributes=edge_attributes,
    ).to(F.ctx())
    itemset = gb.ItemSetDict(
        {
            "n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
            "n1": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
        }
    )

    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler

    sampler_dp = Sampler(item_sampler, graph, fanouts, replace=True)

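    # All sampled indices, indptr values, and original node IDs must be
    # non-negative.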
    for data in sampler_dp:
        for sampledsubgraph in data.sampled_subgraphs:
            for _, value in sampledsubgraph.sampled_csc.items():
                assert torch.equal(
                    torch.ge(value.indices, torch.zeros(len(value.indices))),
                    torch.ones(len(value.indices)),
                )
                assert torch.equal(
                    torch.ge(value.indptr, torch.zeros(len(value.indptr))),
                    torch.ones(len(value.indptr)),
                )
            for _, value in sampledsubgraph.original_column_node_ids.items():
                assert torch.equal(
                    torch.ge(value, torch.zeros(len(value))),
                    torch.ones(len(value)),
                )
            for _, value in sampledsubgraph.original_row_node_ids.items():
                assert torch.equal(
                    torch.ge(value, torch.zeros(len(value))),
                    torch.ones(len(value)),
                )


@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Fails due to randomness on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_without_dedpulication_Homo(labor):
    graph = dgl.graph(
        ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
    )
    graph = gb.from_dglgraph(graph, True).to(F.ctx())
    seed_nodes = torch.LongTensor([0, 3, 4])

    itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
        F.ctx()
    )
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts, deduplicate=False)

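    # Expected per-layer results with deduplicate=False.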
    length = [17, 7]
    compacted_indices = [
        (torch.arange(0, 10) + 7).to(F.ctx()),
        (torch.arange(0, 4) + 3).to(F.ctx()),
    ]
    indptr = [
        torch.tensor([0, 1, 2, 4, 4, 6, 8, 10]).to(F.ctx()),
        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
    ]
    seeds = [
        torch.tensor([0, 3, 4, 5, 2, 2, 4]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            assert len(sampled_subgraph.original_row_node_ids) == length[step]
            assert torch.equal(
                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
            assert torch.equal(
                sampled_subgraph.original_column_node_ids, seeds[step]
            )


@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Heterogenous sampling not yet supported on GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_without_dedpulication_Hetero(labor):
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts, deduplicate=False)
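    # Expected per-layer CSC formats and node ID mappings with
    # deduplicate=False.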
    csc_formats = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([4, 5, 6, 7]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4, 6, 8]),
                indices=torch.tensor([2, 3, 4, 5, 6, 7, 8, 9]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 2, 3]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0]),
                indices=torch.tensor([], dtype=torch.int64),
            ),
        },
    ]
    original_column_node_ids = [
        {
            "n1": torch.tensor([0, 1, 1, 0]),
            "n2": torch.tensor([0, 1]),
        },
        {
            "n1": torch.tensor([], dtype=torch.int64),
            "n2": torch.tensor([0, 1]),
        },
    ]
    original_row_node_ids = [
        {
            "n1": torch.tensor([0, 1, 1, 0, 0, 1, 1, 0]),
            "n2": torch.tensor([0, 1, 0, 2, 0, 1, 0, 1, 0, 2]),
        },
        {
            "n1": torch.tensor([0, 1, 1, 0]),
            "n2": torch.tensor([0, 1]),
        },
    ]

    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            for ntype in ["n1", "n2"]:
                assert torch.equal(
                    sampled_subgraph.original_row_node_ids[ntype],
                    original_row_node_ids[step][ntype],
                )
                assert torch.equal(
                    sampled_subgraph.original_column_node_ids[ntype],
                    original_column_node_ids[step][ntype],
                )
            for etype in ["n1:e1:n2", "n2:e2:n1"]:
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indices,
                    csc_formats[step][etype].indices,
                )
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indptr,
                    csc_formats[step][etype].indptr,
                )


@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Fails due to randomness on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo(labor):
    torch.manual_seed(1205)
    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
    graph = gb.from_dglgraph(graph, True).to(F.ctx())
    seed_nodes = torch.LongTensor([0, 3, 4])

    itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
        F.ctx()
    )
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        replace=False,
        deduplicate=True,
    )

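    # Expected per-layer results with deduplicate=True.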
    original_row_node_ids = [
        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
    ]
    compacted_indices = [
        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),
        torch.tensor([3, 4, 4, 2]).to(F.ctx()),
    ]
    indptr = [
        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),
        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
    ]
    seeds = [
        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            assert torch.equal(
                sampled_subgraph.original_row_node_ids,
                original_row_node_ids[step],
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
            assert torch.equal(
                sampled_subgraph.original_column_node_ids, seeds[step]
            )


@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Heterogenous sampling not yet supported on GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Hetero(labor):
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )
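    # Expected per-layer CSC formats and node ID mappings with
    # deduplicate=True.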
    csc_formats = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 1, 0]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 2, 0, 1]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 1, 0]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0]),
                indices=torch.tensor([], dtype=torch.int64),
            ),
        },
    ]
    original_column_node_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1]),
        },
        {
            "n1": torch.tensor([], dtype=torch.int64),
            "n2": torch.tensor([0, 1]),
        },
    ]
    original_row_node_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1]),
        },
    ]

    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            for ntype in ["n1", "n2"]:
                assert torch.equal(
                    sampled_subgraph.original_row_node_ids[ntype],
                    original_row_node_ids[step][ntype],
                )
                assert torch.equal(
                    sampled_subgraph.original_column_node_ids[ntype],
                    original_column_node_ids[step][ntype],
                )
            for etype in ["n1:e1:n2", "n2:e2:n1"]:
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indices,
                    csc_formats[step][etype].indices,
                )
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indptr,
                    csc_formats[step][etype].indptr,
                )