# test_subgraph_sampler.py
1
import unittest
2
3

from enum import Enum
4
5
from functools import partial

6
7
import backend as F

8
import dgl
9
import dgl.graphbolt as gb
10
11
import pytest
import torch
12
from torchdata.datapipes.iter import Mapper
13

14
15
from . import gb_test_utils

16

# Skip all tests on GPU when sampling with TemporalNeighborSampler.
def _check_sampler_type(sampler_type):
    """Skip the running test when temporal sampling would run off the CPU.

    Args:
        sampler_type: the ``SamplerType`` member a test was parametrized with.

    ``pytest.skip`` is called only when ``sampler_type`` is
    ``SamplerType.Temporal`` and the backend default context is not
    ``"cpu"``; in every other case this returns ``None`` and does nothing.
    """
    is_temporal = sampler_type == SamplerType.Temporal
    runs_off_cpu = F._default_context_str != "cpu"
    if not (is_temporal and runs_off_cpu):
        return
    pytest.skip(
        "TemporalNeighborSampler sampling tests are only supported on CPU."
    )


# Sampler flavours the parametrized tests choose between. Members, in order:
#   Normal   -> 0
#   Layer    -> 1
#   Temporal -> 2
SamplerType = Enum(
    "SamplerType",
    [("Normal", 0), ("Layer", 1), ("Temporal", 2)],
)


def _get_sampler(sampler_type):
    """Return the sampler constructor matching ``sampler_type``.

    The result is called as ``sampler(datapipe, graph, fanouts, ...)``.
    For ``SamplerType.Temporal`` the ``TemporalNeighborSampler`` is returned
    with ``node_timestamp_attr_name`` and ``edge_timestamp_attr_name`` both
    pre-bound to ``"timestamp"``.

    Args:
        sampler_type: a ``SamplerType`` member.

    Returns:
        A callable that builds the sampling datapipe.

    Raises:
        ValueError: if ``sampler_type`` is not a known ``SamplerType``.
            Previously any unknown value silently fell through to the
            temporal sampler, which would hide parametrization mistakes.
    """
    if sampler_type == SamplerType.Normal:
        return gb.NeighborSampler
    if sampler_type == SamplerType.Layer:
        return gb.LayerNeighborSampler
    if sampler_type == SamplerType.Temporal:
        return partial(
            gb.TemporalNeighborSampler,
            node_timestamp_attr_name="timestamp",
            edge_timestamp_attr_name="timestamp",
        )
    raise ValueError(f"Unknown sampler type: {sampler_type!r}")


def test_SubgraphSampler_invoke():
    """Pulling from a bare ``SubgraphSampler`` raises NotImplementedError.

    The check is done both for the class constructor and for the functional
    form ``sample_subgraph``.
    """
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())

    # Invoke via class constructor.
    datapipe = gb.SubgraphSampler(item_sampler)
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))

    # Invoke via functional form.
    datapipe = item_sampler.sample_subgraph()
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))


@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor):
    """Neighbor samplers work via both the constructor and the functional form.

    ``labor`` selects ``LayerNeighborSampler`` (True) or ``NeighborSampler``
    (False). 10 seed nodes with ``batch_size=2`` must give 5 minibatches.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke via class constructor.
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts)
    assert len(list(datapipe)) == 5

    # Invoke via functional form.
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_fanouts(labor):
    """``fanouts`` is accepted both as a list of tensors and a list of ints.

    ``labor`` selects ``sample_layer_neighbor`` (True) or ``sample_neighbor``
    (False); each fanout form must give 5 minibatches of 2 seeds.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2

    # `fanouts` is a list of tensors.
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5

    # `fanouts` is a list of integers.
    fanouts = [2] * num_layer
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Node(sampler_type):
    """Node-seeded sampling on a random homogeneous graph yields 5 batches.

    For the temporal sampler, node/edge ``"timestamp"`` attributes are
    attached to the graph and each seed node carries a timestamp as well.
    """
    _check_sampler_type(sampler_type)
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    items = torch.arange(10)
    names = "seed_nodes"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
        graph.edge_attributes = {
            "timestamp": torch.arange(len(graph.indices)).to(F.ctx())
        }
        items = (items, torch.arange(10))
        names = ("seed_nodes", "timestamp")
    itemset = gb.ItemSet(items, names=names)
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    sampler_dp = sampler(item_sampler, graph, fanouts)
    assert len(list(sampler_dp)) == 5


def to_link_batch(data):
    """Return a new ``gb.MiniBatch`` whose ``node_pairs`` field is ``data``."""
    return gb.MiniBatch(node_pairs=data)


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link(sampler_type):
    """Link-seeded sampling followed by seed-edge exclusion yields 5 batches.

    10 node pairs with ``batch_size=2`` give 5 minibatches. For the temporal
    sampler, every pair also carries a timestamp.
    """
    _check_sampler_type(sampler_type)
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    items = torch.arange(20).reshape(-1, 2)
    names = "node_pairs"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
        graph.edge_attributes = {
            "timestamp": torch.arange(len(graph.indices)).to(F.ctx())
        }
        items = (items, torch.arange(10))
        names = ("node_pairs", "timestamp")
    itemset = gb.ItemSet(items, names=names)
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    datapipe = sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_With_Negative(sampler_type):
    """Like the plain link test, with one uniform negative per positive pair.

    The negative sampler sits between the item sampler and the subgraph
    sampler; the pipeline must still yield 5 minibatches.
    """
    _check_sampler_type(sampler_type)
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    items = torch.arange(20).reshape(-1, 2)
    names = "node_pairs"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
        graph.edge_attributes = {
            "timestamp": torch.arange(len(graph.indices)).to(F.ctx())
        }
        items = (items, torch.arange(10))
        names = ("node_pairs", "timestamp")
    itemset = gb.ItemSet(items, names=names)
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
    sampler = _get_sampler(sampler_type)
    datapipe = sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    assert len(list(datapipe)) == 5


def get_hetero_graph():
    """Build a small fixed heterogeneous graph via ``gb.fused_csc_sampling_graph``.

    COO view of the 10 edges (CSC columns are destinations, ``indices`` are
    sources):
        dst:   [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
        src:   [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
        etype: [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]

    There are 5 nodes: ids 0-1 are "n1" and ids 2-4 are "n2" (see
    ``node_type_offset``). Edge type 0 is "n1:e1:n2" and 1 is "n2:e2:n1".
    """
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    node_type_offset = torch.LongTensor([0, 2, 5])
    return gb.fused_csc_sampling_graph(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=ntypes,
        edge_type_to_id=etypes,
    )


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Node_Hetero(sampler_type):
    """Node-seeded sampling on the fixed hetero graph.

    3 "n2" seeds with ``batch_size=2`` give 2 minibatches, and each
    minibatch carries one sampled subgraph per layer.
    """
    _check_sampler_type(sampler_type)
    graph = get_hetero_graph().to(F.ctx())
    items = torch.arange(3)
    names = "seed_nodes"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {
            "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
        }
        items = (items, torch.randint(0, 10, (3,)))
        names = (names, "timestamp")
    itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    sampler_dp = sampler(item_sampler, graph, fanouts)
    assert len(list(sampler_dp)) == 2
    for minibatch in sampler_dp:
        assert len(minibatch.sampled_subgraphs) == num_layer


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero(sampler_type):
    """Link-seeded sampling on the fixed hetero graph.

    4 "n1:e1:n2" pairs plus 6 "n2:e2:n1" pairs with ``batch_size=2`` must
    yield 5 minibatches after seed-edge exclusion.
    """
    _check_sampler_type(sampler_type)
    graph = get_hetero_graph().to(F.ctx())
    first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
    first_names = "node_pairs"
    second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
    second_names = "node_pairs"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {
            "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
        }
        first_items = (first_items, torch.randint(0, 10, (4,)))
        first_names = (first_names, "timestamp")
        second_items = (second_items, torch.randint(0, 10, (6,)))
        second_names = (second_names, "timestamp")
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                first_items,
                names=first_names,
            ),
            "n2:e2:n1": gb.ItemSet(
                second_items,
                names=second_names,
            ),
        }
    )

    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    datapipe = sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
    """Hetero link sampling with one uniform negative per positive pair.

    Same seeds as the plain hetero link test; the pipeline must still
    yield 5 minibatches.
    """
    _check_sampler_type(sampler_type)
    graph = get_hetero_graph().to(F.ctx())
    first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
    first_names = "node_pairs"
    second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
    second_names = "node_pairs"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {
            "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
        }
        first_items = (first_items, torch.randint(0, 10, (4,)))
        first_names = (first_names, "timestamp")
        second_items = (second_items, torch.randint(0, 10, (6,)))
        second_names = (second_names, "timestamp")
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                first_items,
                names=first_names,
            ),
            "n2:e2:n1": gb.ItemSet(
                second_items,
                names=second_names,
            ),
        }
    )

    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
    sampler = _get_sampler(sampler_type)
    datapipe = sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
    """Seed edge types absent from the graph must not break link sampling.

    The item set is keyed by "n1:e11:n2" / "n2:e22:n1", which are not edge
    types of the hetero graph; the pipeline must still yield 5 minibatches.
    """
    _check_sampler_type(sampler_type)
    graph = get_hetero_graph().to(F.ctx())
    first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
    first_names = "node_pairs"
    second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
    second_names = "node_pairs"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {
            "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
        }
        first_items = (first_items, torch.randint(0, 10, (4,)))
        first_names = (first_names, "timestamp")
        second_items = (second_items, torch.randint(0, 10, (6,)))
        second_names = (second_names, "timestamp")
    # "e11" and "e22" are not valid edge types.
    itemset = gb.ItemSetDict(
        {
            "n1:e11:n2": gb.ItemSet(
                first_items,
                names=first_names,
            ),
            "n2:e22:n1": gb.ItemSet(
                second_items,
                names=second_names,
            ),
        }
    )

    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    datapipe = sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
    """Unknown seed edge types combined with uniform negative sampling.

    The item set is keyed by "n1:e11:n2" / "n2:e22:n1", which are not edge
    types of the hetero graph; with one negative per pair the pipeline must
    still yield 5 minibatches.
    """
    _check_sampler_type(sampler_type)
    graph = get_hetero_graph().to(F.ctx())
    first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
    first_names = "node_pairs"
    second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
    second_names = "node_pairs"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {
            "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
        }
        first_items = (first_items, torch.randint(0, 10, (4,)))
        first_names = (first_names, "timestamp")
        second_items = (second_items, torch.randint(0, 10, (6,)))
        second_names = (second_names, "timestamp")
    # "e11" and "e22" are not valid edge types.
    itemset = gb.ItemSetDict(
        {
            "n1:e11:n2": gb.ItemSet(
                first_items,
                names=first_names,
            ),
            "n2:e22:n1": gb.ItemSet(
                second_items,
                names=second_names,
            ),
        }
    )

    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
    sampler = _get_sampler(sampler_type)
    datapipe = sampler(datapipe, graph, fanouts)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    assert len(list(datapipe)) == 5


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
@pytest.mark.parametrize(
    "replace",
    [False, True],
)
def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
    """Sampling a random hetero graph never produces negative ids/offsets.

    Every ``sampled_csc`` indices/indptr tensor and every original
    row/column node-id tensor of every layer must be element-wise >= 0.
    Sampling with replacement is skipped on GPU.
    """
    _check_sampler_type(sampler_type)
    if F._default_context_str == "gpu" and replace:
        pytest.skip("Sampling with replacement not yet supported on GPU.")
    num_nodes = 5
    num_edges = 9
    num_ntypes = 3
    num_etypes = 3
    (
        csc_indptr,
        indices,
        node_type_offset,
        type_per_edge,
        node_type_to_id,
        edge_type_to_id,
    ) = gb_test_utils.random_hetero_graph(
        num_nodes, num_edges, num_ntypes, num_etypes
    )
    node_attributes = {}
    edge_attributes = {
        "A1": torch.randn(num_edges),
        "A2": torch.randn(num_edges),
    }
    if sampler_type == SamplerType.Temporal:
        node_attributes["timestamp"] = torch.randint(0, 10, (num_nodes,))
        edge_attributes["timestamp"] = torch.randint(0, 10, (num_edges,))
    graph = gb.fused_csc_sampling_graph(
        csc_indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=node_type_to_id,
        edge_type_to_id=edge_type_to_id,
        node_attributes=node_attributes,
        edge_attributes=edge_attributes,
    ).to(F.ctx())

    first_items = torch.tensor([0])
    first_names = "seed_nodes"
    second_items = torch.tensor([0])
    second_names = "seed_nodes"
    if sampler_type == SamplerType.Temporal:
        first_items = (first_items, torch.randint(0, 10, (1,)))
        first_names = (first_names, "timestamp")
        second_items = (second_items, torch.randint(0, 10, (1,)))
        second_names = (second_names, "timestamp")
    itemset = gb.ItemSetDict(
        {
            "n2": gb.ItemSet(first_items, names=first_names),
            "n1": gb.ItemSet(second_items, names=second_names),
        }
    )

    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    sampler_dp = sampler(item_sampler, graph, fanouts, replace=replace)

    # `(t >= 0).all()` is True for empty tensors, matching the previous
    # torch.equal(torch.ge(t, 0s), 1s) check without relying on
    # bool-vs-float dtype promotion inside torch.equal.
    for data in sampler_dp:
        for sampledsubgraph in data.sampled_subgraphs:
            for value in sampledsubgraph.sampled_csc.values():
                assert (value.indices >= 0).all()
                assert (value.indptr >= 0).all()
            for value in sampledsubgraph.original_column_node_ids.values():
                assert (value >= 0).all()
            for value in sampledsubgraph.original_row_node_ids.values():
                assert (value >= 0).all()


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
    """Check the exact non-deduplicated layout on a small homogeneous graph.

    The graph has 8 nodes, yet layer 0 is expected to report 17 row node
    ids, i.e. repeated nodes are kept. The non-temporal samplers are built
    with ``deduplicate=False``; the temporal sampler is built without that
    argument. Temporal node/edge timestamps are all 0 and seed timestamps
    are drawn from [1, 10) — presumably so no neighbor is filtered out by
    time (NOTE(review): confirm against TemporalNeighborSampler semantics).
    """
    _check_sampler_type(sampler_type)
    graph = dgl.graph(
        ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
    )
    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
    seed_nodes = torch.LongTensor([0, 3, 4])
    items = seed_nodes
    names = "seed_nodes"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {
            "timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
        }
        items = (items, torch.randint(1, 10, (3,)))
        names = (names, "timestamp")
    itemset = gb.ItemSet(items, names=names)
    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
        F.ctx()
    )
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    sampler = _get_sampler(sampler_type)
    if sampler_type == SamplerType.Temporal:
        datapipe = sampler(item_sampler, graph, fanouts)
    else:
        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)

    length = [17, 7]
    compacted_indices = [
        (torch.arange(0, 10) + 7).to(F.ctx()),
        (torch.arange(0, 4) + 3).to(F.ctx()),
    ]
    indptr = [
        torch.tensor([0, 1, 2, 4, 4, 6, 8, 10]).to(F.ctx()),
        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
    ]
    seeds = [
        torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            assert len(sampled_subgraph.original_row_node_ids) == length[step]
            assert torch.equal(
                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
            # Column ids are compared after sorting, so only their multiset
            # is pinned, not their order.
            assert torch.equal(
                torch.sort(sampled_subgraph.original_column_node_ids)[0],
                seeds[step],
            )


@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
    """Check the exact non-deduplicated layout on the fixed hetero graph.

    Seeds are "n2" nodes 0 and 1 in a single batch. The non-temporal
    samplers are built with ``deduplicate=False``; the temporal sampler is
    built without that argument and uses all-zero node/edge timestamps with
    seed timestamps drawn from [1, 10). The expected per-layer row/column
    node ids and CSC tensors are pinned below for both edge types.
    """
    _check_sampler_type(sampler_type)
    graph = get_hetero_graph().to(F.ctx())
    items = torch.arange(2)
    names = "seed_nodes"
    if sampler_type == SamplerType.Temporal:
        graph.node_attributes = {
            "timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
        }
        items = (items, torch.randint(1, 10, (2,)))
        names = (names, "timestamp")
    itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    if sampler_type == SamplerType.Temporal:
        datapipe = sampler(item_sampler, graph, fanouts)
    else:
        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)

    csc_formats = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([4, 5, 6, 7]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4, 6, 8]),
                indices=torch.tensor([2, 3, 4, 5, 6, 7, 8, 9]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 2, 3]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0]),
                indices=torch.tensor([], dtype=torch.int64),
            ),
        },
    ]
    original_column_node_ids = [
        {
            "n1": torch.tensor([0, 1, 1, 0]),
            "n2": torch.tensor([0, 1]),
        },
        {
            "n1": torch.tensor([], dtype=torch.int64),
            "n2": torch.tensor([0, 1]),
        },
    ]
    original_row_node_ids = [
        {
            "n1": torch.tensor([0, 1, 1, 0, 0, 1, 1, 0]),
            "n2": torch.tensor([0, 1, 0, 2, 0, 1, 0, 1, 0, 2]),
        },
        {
            "n1": torch.tensor([0, 1, 1, 0]),
            "n2": torch.tensor([0, 1]),
        },
    ]

    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            for ntype in ["n1", "n2"]:
                assert torch.equal(
                    sampled_subgraph.original_row_node_ids[ntype],
                    original_row_node_ids[step][ntype].to(F.ctx()),
                )
                assert torch.equal(
                    sampled_subgraph.original_column_node_ids[ntype],
                    original_column_node_ids[step][ntype].to(F.ctx()),
                )
            for etype in ["n1:e1:n2", "n2:e2:n1"]:
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indices,
                    csc_formats[step][etype].indices.to(F.ctx()),
                )
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indptr,
                    csc_formats[step][etype].indptr.to(F.ctx()),
                )


@pytest.mark.skipif(
    F._default_context_str == "gpu",
    reason="Fails due to different result on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_cpu(labor):
    """Check the exact deduplicated layout on a small homogeneous graph (CPU).

    The expected tensors depend on the RNG, hence the fixed
    ``torch.manual_seed``; GPU runs are skipped because they produce a
    different result. ``labor`` selects ``LayerNeighborSampler`` (True) or
    ``NeighborSampler`` (False).
    """
    torch.manual_seed(1205)
    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
    seed_nodes = torch.LongTensor([0, 3, 4])

    itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
        F.ctx()
    )
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )

    original_row_node_ids = [
        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
    ]
    compacted_indices = [
        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),
        torch.tensor([3, 4, 4, 2]).to(F.ctx()),
    ]
    indptr = [
        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),
        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
    ]
    seeds = [
        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            assert torch.equal(
                sampled_subgraph.original_row_node_ids,
                original_row_node_ids[step],
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
            assert torch.equal(
                sampled_subgraph.original_column_node_ids, seeds[step]
            )


@pytest.mark.skipif(
    F._default_context_str == "cpu",
    reason="Fails due to different result on the CPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_gpu(labor):
    """Check the exact deduplicated layout on a small homogeneous graph (GPU).

    Uses fanout -1 for both layers. CPU runs are skipped because they
    produce a different result. ``labor`` selects ``LayerNeighborSampler``
    (True) or ``NeighborSampler`` (False).
    """
    torch.manual_seed(1205)
    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
    seed_nodes = torch.LongTensor([0, 3, 4])

    itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
        F.ctx()
    )
    num_layer = 2
    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]

    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )

    original_row_node_ids = [
        torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
    ]
    compacted_indices = [
        torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
        torch.tensor([4, 3, 2]).to(F.ctx()),
    ]
    indptr = [
        torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
        torch.tensor([0, 1, 2, 3]).to(F.ctx()),
    ]
    seeds = [
        torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            assert torch.equal(
                sampled_subgraph.original_row_node_ids,
                original_row_node_ids[step],
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
            assert torch.equal(
                sampled_subgraph.original_column_node_ids, seeds[step]
            )


@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Hetero(labor):
    """Check the exact deduplicated layout on the fixed hetero graph.

    Seeds are "n2" nodes 0 and 1 in a single batch, sampled with
    ``deduplicate=True``. ``labor`` selects ``LayerNeighborSampler`` (True)
    or ``NeighborSampler`` (False). Expected per-layer row/column node ids
    and CSC tensors are pinned below for both edge types.
    """
    graph = get_hetero_graph().to(F.ctx())
    itemset = gb.ItemSetDict(
        {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )
    csc_formats = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 1, 0]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 2, 0, 1]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([0, 1, 1, 0]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0]),
                indices=torch.tensor([], dtype=torch.int64),
            ),
        },
    ]
    original_column_node_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1]),
        },
        {
            "n1": torch.tensor([], dtype=torch.int64),
            "n2": torch.tensor([0, 1]),
        },
    ]
    original_row_node_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1]),
        },
    ]

    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            for ntype in ["n1", "n2"]:
                assert torch.equal(
                    sampled_subgraph.original_row_node_ids[ntype],
                    original_row_node_ids[step][ntype].to(F.ctx()),
                )
                assert torch.equal(
                    sampled_subgraph.original_column_node_ids[ntype],
                    original_column_node_ids[step][ntype].to(F.ctx()),
                )
            for etype in ["n1:e1:n2", "n2:e2:n1"]:
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indices,
                    csc_formats[step][etype].indices.to(F.ctx()),
                )
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indptr,
                    csc_formats[step][etype].indptr.to(F.ctx()),
                )