# test_dataloader.py
1
import backend as F
2

3
4
5
import dgl
import dgl.graphbolt
import torch
6
7

from . import gb_test_utils
8
9
10
11
12


def test_DataLoader():
    """Run a complete graphbolt pipeline through a multi-worker DataLoader.

    The pipeline is ItemSampler -> NeighborSampler -> FeatureFetcher ->
    CopyTo, wrapped in ``dgl.graphbolt.DataLoader`` with 4 workers.
    Iterating the loader must yield exactly ``N // B`` minibatches.
    """
    N = 40  # number of seed nodes in the item set
    B = 4  # minibatch size
    # N is a multiple of B, so N // B is the exact batch count no matter
    # how a trailing partial batch would be treated.
    num_nodes = 200  # node count of the random graph and rows per feature

    itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
    graph = gb_test_utils.rand_csc_graph(
        num_nodes, 0.15, bidirection_edge=True
    )

    # Two node features, "a" and "b", each holding one 4-dim row per node.
    keys = [("node", None, "a"), ("node", None, "b")]
    features = {
        key: dgl.graphbolt.TorchBasedFeature(torch.randn(num_nodes, 4))
        for key in keys
    }
    feature_store = dgl.graphbolt.BasicFeatureStore(features)

    item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
    # Two sampling layers with a fanout of 2 neighbors each.
    subgraph_sampler = dgl.graphbolt.NeighborSampler(
        item_sampler,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    feature_fetcher = dgl.graphbolt.FeatureFetcher(
        subgraph_sampler,
        feature_store,
        ["a", "b"],
    )
    # F.ctx() is presumably the backend's test device — confirm in backend.
    device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())

    dataloader = dgl.graphbolt.DataLoader(
        device_transferrer,
        num_workers=4,
    )
    assert len(list(dataloader)) == N // B