# test_dataloader.py (2.43 KB)
1
2
import unittest

3
import backend as F
4

5
6
import dgl
import dgl.graphbolt
7
import pytest
8
import torch
9
10

from . import gb_test_utils
11
12
13
14
15


def test_DataLoader():
    """Check that a full GraphBolt pipeline yields ``N // B`` mini-batches.

    The pipeline is ItemSampler -> NeighborSampler -> FeatureFetcher ->
    CopyTo, wrapped in a ``dgl.graphbolt.DataLoader`` created with
    ``num_workers=4``.
    """
    N = 40  # number of seed nodes in the item set
    B = 4  # mini-batch size handed to the ItemSampler
    itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
    graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)

    # Two random node features, "a" and "b", each a 200 x 4 tensor.
    features = {}
    keys = [("node", None, "a"), ("node", None, "b")]
    features[keys[0]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
    features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
    feature_store = dgl.graphbolt.BasicFeatureStore(features)

    item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
    subgraph_sampler = dgl.graphbolt.NeighborSampler(
        item_sampler,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    feature_fetcher = dgl.graphbolt.FeatureFetcher(
        subgraph_sampler,
        feature_store,
        ["a", "b"],
    )
    # Final stage: move each fetched mini-batch to the backend's device.
    device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())

    dataloader = dgl.graphbolt.DataLoader(
        device_transferrer,
        num_workers=4,
    )
    # N seeds in batches of B must produce exactly N // B mini-batches.
    assert len(list(dataloader)) == N // B
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82


@unittest.skipIf(
    F._default_context_str != "gpu",
    reason="This test requires the GPU.",
)
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
def test_gpu_sampling_DataLoader(overlap_feature_fetch):
    """Check the batch count when the seeds are copied to the device first.

    The graph is moved to ``F.ctx()``, the feature tensors are allocated
    with ``pin_memory=True``, and the seed nodes are copied to ``F.ctx()``
    right after item sampling, ahead of the neighbor sampler. The test runs
    once per value of ``overlap_feature_fetch``, which is forwarded to the
    ``DataLoader``.
    """
    num_items, batch_size = 40, 4
    seeds = dgl.graphbolt.ItemSet(torch.arange(num_items), names="seed_nodes")
    host_graph = gb_test_utils.rand_csc_graph(
        200, 0.15, bidirection_edge=True
    )
    device_graph = host_graph.to(F.ctx())

    # One pinned 200 x 4 random feature per name, created in "a", "b" order.
    store = dgl.graphbolt.BasicFeatureStore(
        {
            ("node", None, feat_name): dgl.graphbolt.TorchBasedFeature(
                torch.randn(200, 4, pin_memory=True)
            )
            for feat_name in ("a", "b")
        }
    )

    # Assemble the pipeline stage by stage; the copy to the device happens
    # before sampling.
    pipe = dgl.graphbolt.ItemSampler(seeds, batch_size=batch_size)
    pipe = pipe.copy_to(F.ctx(), extra_attrs=["seed_nodes"])
    two_hop_fanouts = [torch.LongTensor([2]), torch.LongTensor([2])]
    pipe = dgl.graphbolt.NeighborSampler(
        pipe, device_graph, fanouts=two_hop_fanouts
    )
    pipe = dgl.graphbolt.FeatureFetcher(pipe, store, ["a", "b"])

    loader = dgl.graphbolt.DataLoader(
        pipe, overlap_feature_fetch=overlap_feature_fetch
    )
    batches = list(loader)
    assert len(batches) == num_items // batch_size