test_subgraph_sampler.py 18.9 KB
Newer Older
1
import unittest
2
3
from functools import partial

4
5
import backend as F

6
import dgl
7
import dgl.graphbolt as gb
8
9
import pytest
import torch
10
from torchdata.datapipes.iter import Mapper
11

12
13
from . import gb_test_utils

14

15
16
def test_SubgraphSampler_invoke():
    """The base SubgraphSampler must raise NotImplementedError when iterated,
    whether constructed directly or via the functional datapipe form."""
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())

    # Invoke via class constructor.
    datapipe = gb.SubgraphSampler(item_sampler)
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))

    # Invoke via functional form.
    datapipe = item_sampler.sample_subgraph()
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))


@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor):
    """(Layer)NeighborSampler yields one minibatch per input batch, whether
    constructed directly or via the functional datapipe form."""
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke via class constructor.
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts)
    # 10 seeds / batch_size 2 == 5 minibatches.
    assert len(list(datapipe)) == 5

    # Invoke via functional form.
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5


53
54
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_fanouts(labor):
    """`fanouts` may be supplied as per-layer tensors or plain integers."""
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2

    def build_sampler(fanouts):
        # Dispatch to the functional form matching the sampler under test.
        if labor:
            return item_sampler.sample_layer_neighbor(graph, fanouts)
        return item_sampler.sample_neighbor(graph, fanouts)

    # `fanouts` given as a list of tensors.
    tensor_fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    assert len(list(build_sampler(tensor_fanouts))) == 5

    # `fanouts` given as a list of integers.
    int_fanouts = [2] * num_layer
    assert len(list(build_sampler(int_fanouts))) == 5


79
80
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
    """Node-wise sampling over a homogeneous graph yields one minibatch per
    seed batch."""
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts)
    # 10 seeds / batch_size 2 == 5 minibatches.
    assert len(list(sampler_dp)) == 5
91
92


93
def to_link_batch(data):
    """Wrap a batch of node pairs into a ``gb.MiniBatch``."""
    return gb.MiniBatch(node_pairs=data)
96
97


98
99
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor):
    """Link-wise sampling with seed-edge exclusion yields one minibatch per
    batch of node pairs."""
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    # `partial` with no bound arguments is redundant; pass the function itself.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    # 10 pairs / batch_size 2 == 5 minibatches.
    assert len(list(datapipe)) == 5
111
112


113
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(labor):
    """Link-wise sampling composed with uniform negative sampling still yields
    one minibatch per batch of node pairs."""
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    # One negative sample per positive pair.
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    # `partial` with no bound arguments is redundant; pass the function itself.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    assert len(list(datapipe)) == 5
127
128


129
130
131
132
133
134
135
def get_hetero_graph():
    """Build a small heterogeneous CSC sampling graph for the tests below.

    COO form of the 10 edges:
        src: [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
        dst: [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    5 nodes total; per ``node_type_offset``, "n1" covers global ids [0, 2)
    and "n2" covers [2, 5). ``type_per_edge`` tags each edge with one of the
    two edge-type ids declared in ``edge_type_to_id``.
    """
    node_type_to_id = {"n1": 0, "n2": 1}
    edge_type_to_id = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    return gb.fused_csc_sampling_graph(
        torch.LongTensor([0, 2, 4, 6, 8, 10]),
        torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]),
        node_type_offset=torch.LongTensor([0, 2, 5]),
        type_per_edge=torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]),
        node_type_to_id=node_type_to_id,
        edge_type_to_id=edge_type_to_id,
    )
149
150


151
152
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node_Hetero(labor):
    """Node-wise sampling on a heterogeneous graph produces one sampled
    subgraph per layer in every minibatch."""
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {"n2": gb.ItemSet(torch.arange(3), names="seed_nodes")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts)
    # 3 seeds / batch_size 2 == 2 minibatches (last one partial).
    assert len(list(sampler_dp)) == 2
    for minibatch in sampler_dp:
        assert len(minibatch.sampled_subgraphs) == num_layer
165
166


167
168
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor):
    """Link-wise sampling over per-edge-type node pairs with seed-edge
    exclusion yields one minibatch per batch."""
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
                names="node_pairs",
            ),
            "n2:e2:n1": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
                names="node_pairs",
            ),
        }
    )
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    # `partial` with no bound arguments is redundant; pass the function itself.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    # 10 pairs total / batch_size 2 == 5 minibatches.
    assert len(list(datapipe)) == 5
190
191


192
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
    """Heterogeneous link-wise sampling composed with uniform negative
    sampling still yields one minibatch per batch."""
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
                names="node_pairs",
            ),
            "n2:e2:n1": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
                names="node_pairs",
            ),
        }
    )
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    # One negative sample per positive pair.
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(datapipe, graph, fanouts)
    # `partial` with no bound arguments is redundant; pass the function itself.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    assert len(list(datapipe)) == 5
216
217


218
219
220
221
@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Random_Hetero_Graph(labor):
    """Sampling with replacement on a random heterogeneous graph produces
    only valid (non-negative) indices, indptr values, and node ids."""
    num_nodes = 5
    num_edges = 9
    num_ntypes = 3
    num_etypes = 3
    (
        csc_indptr,
        indices,
        node_type_offset,
        type_per_edge,
        node_type_to_id,
        edge_type_to_id,
    ) = gb_test_utils.random_hetero_graph(
        num_nodes, num_edges, num_ntypes, num_etypes
    )
    edge_attributes = {
        "A1": torch.randn(num_edges),
        "A2": torch.randn(num_edges),
    }
    graph = gb.fused_csc_sampling_graph(
        csc_indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=node_type_to_id,
        edge_type_to_id=edge_type_to_id,
        edge_attributes=edge_attributes,
    ).to(F.ctx())
    itemset = gb.ItemSetDict(
        {
            "n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
            "n1": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
        }
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts, replace=True)

    for data in sampler_dp:
        for sampledsubgraph in data.sampled_subgraphs:
            # `(x >= 0).all()` replaces the verbose
            # `torch.equal(torch.ge(x, zeros), ones)` pattern: same check,
            # no throwaway zeros/ones tensors.
            for _, value in sampledsubgraph.sampled_csc.items():
                assert (value.indices >= 0).all()
                assert (value.indptr >= 0).all()
            for _, value in sampledsubgraph.original_column_node_ids.items():
                assert (value >= 0).all()
            for _, value in sampledsubgraph.original_row_node_ids.items():
                assert (value >= 0).all()
286
287


288
289
290
291
@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Fails due to randomness on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_without_dedpulication_Homo(labor):
    """With ``deduplicate=False``, repeatedly sampled nodes are kept as
    duplicates; check exact per-layer CSC contents on a fixed graph."""
    graph = dgl.graph(
        ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
    )
    graph = gb.from_dglgraph(graph, True).to(F.ctx())
    seed_nodes = torch.LongTensor([0, 3, 4])
    itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
        F.ctx()
    )
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts, deduplicate=False)

    # Per-layer expectations, outermost layer first.
    expected_lengths = [17, 7]
    expected_indices = [
        (torch.arange(0, 10) + 7).to(F.ctx()),
        (torch.arange(0, 4) + 3).to(F.ctx()),
    ]
    expected_indptr = [
        torch.tensor([0, 1, 2, 4, 4, 6, 8, 10]).to(F.ctx()),
        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
    ]
    expected_seeds = [
        torch.tensor([0, 3, 4, 5, 2, 2, 4]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
    ]
    for minibatch in sampler_dp:
        for layer, subgraph in enumerate(minibatch.sampled_subgraphs):
            assert (
                len(subgraph.original_row_node_ids) == expected_lengths[layer]
            )
            assert torch.equal(
                subgraph.sampled_csc.indices, expected_indices[layer]
            )
            assert torch.equal(
                subgraph.sampled_csc.indptr, expected_indptr[layer]
            )
            assert torch.equal(
                subgraph.original_column_node_ids, expected_seeds[layer]
            )


@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_without_dedpulication_Hetero(labor):
    """With ``deduplicate=False`` on a heterogeneous graph, verify exact
    per-layer CSC formats and (possibly duplicated) original node ids."""
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts, deduplicate=False)

    # Expected values per layer (outermost layer first).
    expected_csc = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([4, 5, 6, 7]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4, 6, 8]),
                indices=torch.tensor([2, 3, 4, 5, 6, 7, 8, 9]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 2, 3]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0]),
                indices=torch.tensor([], dtype=torch.int64),
            ),
        },
    ]
    expected_column_ids = [
        {
            "n1": torch.tensor([0, 1, 1, 0]),
            "n2": torch.tensor([0, 1]),
        },
        {
            "n1": torch.tensor([], dtype=torch.int64),
            "n2": torch.tensor([0, 1]),
        },
    ]
    expected_row_ids = [
        {
            "n1": torch.tensor([0, 1, 1, 0, 0, 1, 1, 0]),
            "n2": torch.tensor([0, 1, 0, 2, 0, 1, 0, 1, 0, 2]),
        },
        {
            "n1": torch.tensor([0, 1, 1, 0]),
            "n2": torch.tensor([0, 1]),
        },
    ]

    for minibatch in sampler_dp:
        for layer, subgraph in enumerate(minibatch.sampled_subgraphs):
            for ntype in ("n1", "n2"):
                assert torch.equal(
                    subgraph.original_row_node_ids[ntype],
                    expected_row_ids[layer][ntype].to(F.ctx()),
                )
                assert torch.equal(
                    subgraph.original_column_node_ids[ntype],
                    expected_column_ids[layer][ntype].to(F.ctx()),
                )
            for etype in ("n1:e1:n2", "n2:e2:n1"):
                assert torch.equal(
                    subgraph.sampled_csc[etype].indices,
                    expected_csc[layer][etype].indices.to(F.ctx()),
                )
                assert torch.equal(
                    subgraph.sampled_csc[etype].indptr,
                    expected_csc[layer][etype].indptr.to(F.ctx()),
                )
411
412


413
414
415
416
@unittest.skipIf(
    F._default_context_str != "cpu",
    reason="Fails due to randomness on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo(labor):
    """With ``deduplicate=True``, sampled nodes are unique; verify exact
    per-layer CSC contents on a fixed graph and seed."""
    # Fixed seed so the neighbor selection is reproducible.
    torch.manual_seed(1205)
    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
    graph = gb.from_dglgraph(graph, True).to(F.ctx())
    seed_nodes = torch.LongTensor([0, 3, 4])
    itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
        F.ctx()
    )
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(
        item_sampler,
        graph,
        fanouts,
        replace=False,
        deduplicate=True,
    )

    # Per-layer expectations, outermost layer first.
    expected_row_ids = [
        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
    ]
    expected_indices = [
        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),
        torch.tensor([3, 4, 4, 2]).to(F.ctx()),
    ]
    expected_indptr = [
        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),
        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
    ]
    expected_seeds = [
        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
    ]
    for minibatch in sampler_dp:
        for layer, subgraph in enumerate(minibatch.sampled_subgraphs):
            assert torch.equal(
                subgraph.original_row_node_ids,
                expected_row_ids[layer],
            )
            assert torch.equal(
                subgraph.sampled_csc.indices, expected_indices[layer]
            )
            assert torch.equal(
                subgraph.sampled_csc.indptr, expected_indptr[layer]
            )
            assert torch.equal(
                subgraph.original_column_node_ids, expected_seeds[layer]
            )


@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Hetero(labor):
    """With ``deduplicate=True`` on a heterogeneous graph, verify exact
    per-layer CSC formats and unique original node ids."""
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )

    # Expected values per layer (outermost layer first).
    expected_csc = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 1, 0]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 2, 0, 1]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 1, 0]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0]),
                indices=torch.tensor([], dtype=torch.int64),
            ),
        },
    ]
    expected_column_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1]),
        },
        {
            "n1": torch.tensor([], dtype=torch.int64),
            "n2": torch.tensor([0, 1]),
        },
    ]
    expected_row_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1]),
        },
    ]

    for minibatch in sampler_dp:
        for layer, subgraph in enumerate(minibatch.sampled_subgraphs):
            for ntype in ("n1", "n2"):
                assert torch.equal(
                    subgraph.original_row_node_ids[ntype],
                    expected_row_ids[layer][ntype].to(F.ctx()),
                )
                assert torch.equal(
                    subgraph.original_column_node_ids[ntype],
                    expected_column_ids[layer][ntype].to(F.ctx()),
                )
            for etype in ("n1:e1:n2", "n2:e2:n1"):
                assert torch.equal(
                    subgraph.sampled_csc[etype].indices,
                    expected_csc[layer][etype].indices.to(F.ctx()),
                )
                assert torch.equal(
                    subgraph.sampled_csc[etype].indptr,
                    expected_csc[layer][etype].indptr.to(F.ctx()),
                )