test_dataloader.py 3.06 KB
Newer Older
1
2
import unittest

3
import backend as F
4

5
6
import dgl
import dgl.graphbolt
7
import pytest
8
import torch
9

10
11
import torchdata.dataloader2.graph as dp_utils

12
from . import gb_test_utils
13
14
15
16
17


def test_DataLoader():
    """Run the CPU sampling pipeline through a multi-worker DataLoader.

    The pipeline chains item sampling, two-layer neighbor sampling, feature
    fetching and a copy to ``F.ctx()``.  The test passes when the loader
    yields exactly ``num_seeds // batch_size`` minibatches.
    """
    gb = dgl.graphbolt
    num_seeds, batch_size = 40, 4

    seeds = gb.ItemSet(torch.arange(num_seeds), names="seed_nodes")
    csc_graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)

    # Two node features, "a" and "b", each backed by a random 200 x 4 tensor.
    feature_keys = (("node", None, "a"), ("node", None, "b"))
    store = gb.BasicFeatureStore(
        {
            key: gb.TorchBasedFeature(torch.randn(200, 4))
            for key in feature_keys
        }
    )

    # Each stage wraps the previous one, so a single name is rebound.
    datapipe = gb.ItemSampler(seeds, batch_size=batch_size)
    datapipe = gb.NeighborSampler(
        datapipe,
        csc_graph,
        fanouts=[torch.LongTensor([2]), torch.LongTensor([2])],
    )
    datapipe = gb.FeatureFetcher(datapipe, store, ["a", "b"])
    datapipe = gb.CopyTo(datapipe, F.ctx())

    loader = gb.DataLoader(datapipe, num_workers=4)
    minibatches = list(loader)
    assert len(minibatches) == num_seeds // batch_size
44
45
46
47
48
49
50


@unittest.skipIf(
    F._default_context_str != "gpu",
    reason="This test requires the GPU.",
)
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
    """Check the DataLoader when sampling runs on the GPU context.

    Seeds are copied to ``F.ctx()`` before neighbor sampling, and the graph is
    moved there as well.  Feature tensors are allocated with
    ``pin_memory=True``.

    Parameters
    ----------
    overlap_feature_fetch : bool
        Forwarded to ``dgl.graphbolt.DataLoader``.
    enable_feature_fetch : bool
        Whether a ``FeatureFetcher`` stage is appended to the pipeline.

    The test asserts that the loader's datapipe graph contains exactly one
    ``Awaiter`` and one ``Bufferer`` when both flags are true and none
    otherwise, and that ``num_seeds // batch_size`` minibatches are produced.
    """
    gb = dgl.graphbolt
    num_seeds, batch_size = 40, 4

    seeds = gb.ItemSet(torch.arange(num_seeds), names="seed_nodes")
    csc_graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
    csc_graph = csc_graph.to(F.ctx())

    # Two node features, "a" and "b", each a random 200 x 4 pinned tensor.
    feature_keys = (("node", None, "a"), ("node", None, "b"))
    store = gb.BasicFeatureStore(
        {
            key: gb.TorchBasedFeature(torch.randn(200, 4, pin_memory=True))
            for key in feature_keys
        }
    )

    pipeline = gb.ItemSampler(seeds, batch_size=batch_size)
    pipeline = pipeline.copy_to(F.ctx(), extra_attrs=["seed_nodes"])
    pipeline = gb.NeighborSampler(
        pipeline,
        csc_graph,
        fanouts=[torch.LongTensor([2]), torch.LongTensor([2])],
    )
    if enable_feature_fetch:
        pipeline = gb.FeatureFetcher(pipeline, store, ["a", "b"])

    loader = gb.DataLoader(
        pipeline, overlap_feature_fetch=overlap_feature_fetch
    )

    # Overlap stages are expected only when feature fetching is both enabled
    # and overlapped.
    expected_cnt = 1 if (enable_feature_fetch and overlap_feature_fetch) else 0

    # Inspect the pipeline the loader actually holds, not the one passed in.
    pipeline_graph = dp_utils.traverse_dps(loader.dataset)

    awaiter_nodes = dp_utils.find_dps(pipeline_graph, gb.Awaiter)
    assert len(awaiter_nodes) == expected_cnt

    bufferer_nodes = dp_utils.find_dps(pipeline_graph, gb.Bufferer)
    assert len(bufferer_nodes) == expected_cnt

    assert len(list(loader)) == num_seeds // batch_size