# test_dataloader.py
1
import backend as F
2

3
4
5
import dgl
import dgl.graphbolt
import torch
6
import torch.multiprocessing as mp
7
8

from . import gb_test_utils
9
10
11
12
13


def test_DataLoader():
    """Check that a GraphBolt ``DataLoader`` yields ``N // B`` batches.

    The pipeline under test is chained as
    ItemSampler -> NeighborSampler -> FeatureFetcher -> CopyTo -> DataLoader,
    and the loader is run with four workers.
    """
    N = 40  # number of seed items in the item set
    B = 4  # batch size; N is a multiple of B, so N // B is exact

    itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
    graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)

    # Two random features, "a" and "b", each backed by a (200, 4) tensor.
    # NOTE(review): 200 matches the first argument of rand_csc_graph above,
    # presumably the node count -- confirm in gb_test_utils.
    features = {}
    keys = [("node", None, "a"), ("node", None, "b")]
    features[keys[0]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
    features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
    feature_store = dgl.graphbolt.BasicFeatureStore(features)

    # Batch the item set, B items at a time.
    item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)

    # Two fanout entries, each a LongTensor holding the single value 2.
    subgraph_sampler = dgl.graphbolt.NeighborSampler(
        item_sampler,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )

    # Request features "a" and "b" from the store for every sampled batch.
    feature_fetcher = dgl.graphbolt.FeatureFetcher(
        subgraph_sampler,
        feature_store,
        ["a", "b"],
    )

    # Copy each batch to the device reported by the test backend.
    device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())

    dataloader = dgl.graphbolt.DataLoader(
        device_transferrer,
        num_workers=4,
    )

    # Exhausting the loader must produce exactly one batch per B items.
    assert len(list(dataloader)) == N // B