Unverified Commit 8adb53bb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[GraphBolt] Add single process dataloader and ux (#5941)

parent 39890c0c
...@@ -12,6 +12,7 @@ from .feature_store import * ...@@ -12,6 +12,7 @@ from .feature_store import *
from .feature_fetcher import * from .feature_fetcher import *
from .copy_to import * from .copy_to import *
from .dataset import * from .dataset import *
from .dataloader import *
from .subgraph_sampler import * from .subgraph_sampler import *
......
"""Graph Bolt DataLoaders""" """Graph Bolt DataLoaders"""
import torch.utils.data
class SingleProcessDataLoader(torch.utils.data.DataLoader):
"""Single process DataLoader.
Iterates over the data pipeline in the main process.
Parameters
----------
datapipe : DataPipe
The data pipeline.
"""
# In the single process dataloader case, we don't need to do any
# modifications to the datapipe, and we just PyTorch's native
# dataloader as-is.
#
# The exception is that batch_size should be None, since we already
# have minibatch sampling and collating in MinibatchSampler.
def __init__(self, datapipe):
super().__init__(datapipe, batch_size=None, num_workers=0)
import backend as F
import dgl
import dgl.graphbolt
import gb_test_utils
import torch
def test_DataLoader():
N = 32
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N))
graph = gb_test_utils.rand_csc_graph(200, 0.15)
features = dgl.graphbolt.feature_store.TorchBasedFeatureStore(
torch.randn(200, 4)
)
labels = dgl.graphbolt.feature_store.TorchBasedFeatureStore(
torch.randint(0, 10, (200,))
)
def sampler_func(data):
adjs = []
seeds = data
for hop in range(2):
sg = graph.sample_neighbors(seeds, torch.LongTensor([2]))
seeds = sg.indices
adjs.insert(0, sg)
input_nodes = seeds
output_nodes = data
return input_nodes, output_nodes, adjs
def fetch_func(data):
input_nodes, output_nodes, adjs = data
input_features = features.read(input_nodes)
output_labels = labels.read(output_nodes)
return input_features, output_labels, adjs
minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
subgraph_sampler = dgl.graphbolt.SubgraphSampler(
minibatch_sampler,
sampler_func,
)
feature_fetcher = dgl.graphbolt.FeatureFetcher(subgraph_sampler, fetch_func)
device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
dataloader = dgl.graphbolt.SingleProcessDataLoader(device_transferrer)
assert len(list(dataloader)) == N // B
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment