# test_dataloader.py
1
import backend as F
2

3
4
5
import dgl
import dgl.graphbolt
import torch
6
import torch.multiprocessing as mp
7
8

from . import gb_test_utils
9
10
11


def test_DataLoader():
    """Smoke-test ``dgl.graphbolt.DataLoader`` with multiple worker processes.

    Builds the pipeline ItemSampler -> NeighborSampler -> FeatureFetcher ->
    CopyTo over a random CSC graph, wraps it in a 4-worker DataLoader, and
    checks that iterating it yields ``N // B`` minibatches.

    The multiprocessing start method is switched to "spawn" for the
    duration of the test and restored afterwards, so the global setting
    does not leak into other tests run in the same interpreter.
    """
    # https://pytorch.org/docs/master/notes/multiprocessing.html#cuda-in-multiprocessing
    # The linked PyTorch note explains why "spawn" is required when CUDA is
    # in use with worker processes.
    previous_start_method = mp.get_start_method(allow_none=True)
    mp.set_start_method("spawn", force=True)
    try:
        N = 40  # number of seed nodes
        B = 4  # batch size
        itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
        graph = gb_test_utils.rand_csc_graph(
            200, 0.15, bidirection_edge=True
        )

        # Two node features, "a" and "b", with 200 rows each.
        # NOTE(review): assumes 200 matches the node count passed to
        # rand_csc_graph above — keep the two in sync.
        keys = [("node", None, "a"), ("node", None, "b")]
        features = {
            key: dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
            for key in keys
        }
        feature_store = dgl.graphbolt.BasicFeatureStore(features)

        item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
        # Two fanout entries of 2 each — presumably one per sampling layer.
        subgraph_sampler = dgl.graphbolt.NeighborSampler(
            item_sampler,
            graph,
            fanouts=[torch.LongTensor([2]) for _ in range(2)],
        )
        feature_fetcher = dgl.graphbolt.FeatureFetcher(
            subgraph_sampler,
            feature_store,
            ["a", "b"],
        )
        device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())

        dataloader = dgl.graphbolt.DataLoader(
            device_transferrer,
            num_workers=4,
        )
        # N is a multiple of B, so the expected batch count is exact.
        assert len(list(dataloader)) == N // B
    finally:
        # Restore the caller's start method (None resets to "unset").
        mp.set_start_method(previous_start_method, force=True)