Unverified commit ac53c1fa, authored by Quan (Andy) Gan, committed by GitHub
Browse files

[GraphBolt] Subgraph sampler base datapipe (#5904)


Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 6ba6dd60
...@@ -9,6 +9,7 @@ from .graph_storage import * ...@@ -9,6 +9,7 @@ from .graph_storage import *
from .itemset import * from .itemset import *
from .minibatch_sampler import * from .minibatch_sampler import *
from .feature_store import * from .feature_store import *
from .subgraph_sampler import *
def load_graphbolt(): def load_graphbolt():
......
"""Subgraph samplers"""
from torchdata.datapipes.iter import Mapper
class SubgraphSampler(Mapper):
    """A subgraph sampler.

    This class defines no methods of its own; all behavior is inherited from
    ``Mapper``. It is an iterator equivalent to the following:

    .. code:: python

       for data in datapipe:
           yield fn(data)

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    fn : callable
        The subgraph sampling function.
    """
import dgl.graphbolt
import scipy.sparse as sp
import torch
def rand_csc_graph(N, density):
    """Create a random GraphBolt graph from a symmetrized sparse matrix.

    A random ``N`` x ``N`` sparse matrix with the given density is generated,
    added to its own transpose, and converted to CSC format. Its ``indptr``
    and ``indices`` arrays are then handed to ``dgl.graphbolt.from_csc``.

    Parameters
    ----------
    N : int
        Number of rows and columns of the random sparse matrix.
    density : float
        Density passed to ``scipy.sparse.random``.

    Returns
    -------
    The object returned by ``dgl.graphbolt.from_csc``.
    """
    symmetric = sp.random(N, N, density)
    # Adding the transpose makes the sparsity pattern symmetric.
    symmetric = (symmetric + symmetric.T).tocsc()
    return dgl.graphbolt.from_csc(
        torch.LongTensor(symmetric.indptr),
        torch.LongTensor(symmetric.indices),
    )
import dgl
import dgl.graphbolt
import gb_test_utils
import pytest
import torch
import torchdata.datapipes as dp
def get_graphbolt_sampler_func():
    """Build a two-round sampling function over a random GraphBolt graph.

    The graph is created once via ``gb_test_utils.rand_csc_graph(20, 0.15)``
    and captured by the returned closure.

    Returns
    -------
    callable
        ``sampler_func(data)``, which returns a 3-tuple of the final seeds,
        the original ``data``, and the list of sampled subgraphs ordered
        from the last sampling round to the first.
    """
    csc_graph = gb_test_utils.rand_csc_graph(20, 0.15)

    def sampler_func(data):
        frontier = data
        sampled = []
        for _ in range(2):
            # NOTE(review): the [2] tensor is presumably the fanout - confirm
            # against the sample_neighbors API.
            subgraph = csc_graph.sample_neighbors(
                frontier, torch.LongTensor([2])
            )
            # The sampled subgraph's indices seed the next round.
            frontier = subgraph.indices
            sampled.append(subgraph)
        # Most recently sampled subgraph goes first.
        sampled.reverse()
        return frontier, data, sampled

    return sampler_func
def get_dgl_sampler_func():
    """Build a sampling function backed by DGL's ``NeighborSampler``.

    A random graph from ``dgl.rand_graph(20, 60)`` is passed through
    ``dgl.add_reverse_edges`` and paired with a ``NeighborSampler([2, 2])``;
    both are captured by the returned callable.

    Returns
    -------
    callable
        A one-argument function returning ``sampler.sample(graph, data)``.
    """
    base_graph = dgl.rand_graph(20, 60)
    bidirected = dgl.add_reverse_edges(base_graph)
    neighbor_sampler = dgl.dataloading.NeighborSampler([2, 2])
    return lambda data: neighbor_sampler.sample(bidirected, data)
def get_graphbolt_minibatch_dp():
    """Return a GraphBolt ``MinibatchSampler`` over the items 0..9.

    The items are wrapped in a ``dgl.graphbolt.ItemSet`` and sampled with
    ``batch_size=2``.
    """
    items = torch.arange(10)
    item_set = dgl.graphbolt.ItemSet(items)
    return dgl.graphbolt.MinibatchSampler(item_set, batch_size=2)
def get_torchdata_minibatch_dp():
    """Return a torchdata pipeline over the items 0..9 in batches of 2.

    The items are wrapped in a map-style ``SequenceWrapper``, batched,
    converted to an iter-style datapipe, and collated.
    """
    source = dp.map.SequenceWrapper(torch.arange(10))
    batched = source.batch(2)
    return batched.to_iter_datapipe().collate()
@pytest.mark.parametrize(
    "sampler_func",
    [
        get_graphbolt_sampler_func(),
        get_dgl_sampler_func(),
    ],
)
@pytest.mark.parametrize(
    "minibatch_dp",
    [
        get_graphbolt_minibatch_dp(),
        get_torchdata_minibatch_dp(),
    ],
)
def test_SubgraphSampler(minibatch_dp, sampler_func):
    """Check that ``SubgraphSampler`` yields one output per minibatch.

    Every combination of minibatch datapipe and sampling function is run;
    each datapipe covers 10 items in batches of 2, so 5 outputs are expected.
    """
    pipeline = dgl.graphbolt.SubgraphSampler(minibatch_dp, sampler_func)
    outputs = list(pipeline)
    assert len(outputs) == 5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment