Unverified commit f6dec359 authored by peizhou001, committed by GitHub

[Graphbolt]Subgraph sampler udf (#6129)

parent 4663cb0c
@@ -5,3 +5,4 @@ from .torch_based_feature_store import *
 from .csc_sampling_graph import *
 from .sampled_subgraph_impl import *
 from .uniform_negative_sampler import *
+from .neighbor_sampler import *
"""Neighbor subgraph sampler for GraphBolt."""
from ..subgraph_sampler import SubgraphSampler
from ..utils import unique_and_compact_node_pairs
from .sampled_subgraph_impl import SampledSubgraphImpl


class NeighborSampler(SubgraphSampler):
    """
    Neighbor sampler is responsible for sampling a subgraph from the given
    data. It returns an induced subgraph along with compacted information. In
    the context of a node classification task, the neighbor sampler directly
    utilizes the nodes provided as seed nodes. However, in scenarios involving
    link prediction, an additional pre-processing step is required: gather the
    unique nodes from the given node pairs, encompassing both positive and
    negative node pairs, and employ these nodes as the seed nodes for the
    subsequent steps.
    """

    def __init__(
        self,
        datapipe,
        graph,
        fanouts,
        replace=False,
        prob_name=None,
    ):
        """
        Initialization for a neighbor subgraph sampler.

        Parameters
        ----------
        datapipe : DataPipe
            The datapipe.
        graph : CSCSamplingGraph
            The graph on which to perform subgraph sampling.
        fanouts : list[torch.Tensor]
            The number of edges to be sampled for each node, with or without
            considering edge types. The length of this list implicitly
            determines the number of sampling layers.
        replace : bool
            Boolean indicating whether sampling is performed with or without
            replacement. If True, a value can be selected multiple times;
            otherwise, each value can be selected only once.
        prob_name : str, optional
            The name of an edge attribute used as the sampling weight for
            each node. This attribute tensor should contain (unnormalized)
            probabilities corresponding to each neighboring edge of a node.
            It must be a 1D floating-point or boolean tensor whose number of
            elements equals the total number of edges.

        Examples
        --------
        >>> import torch
        >>> import dgl.graphbolt as gb
        >>> from torchdata.datapipes.iter import Mapper
        >>> def to_link_block(data):
        ...     block = gb.LinkPredictionBlock(node_pair=data)
        ...     return block
        ...
        >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7, 8])
        >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
        >>> graph = gb.from_csc(indptr, indices)
        >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
        >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
        >>> item_set = gb.ItemSet(node_pairs)
        >>> minibatch_sampler = gb.MinibatchSampler(
        ...     item_set, batch_size=1,
        ... )
        >>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
        >>> neg_sampler = gb.UniformNegativeSampler(
        ...     data_block_converter, 2, data_format, graph)
        >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
        ...     torch.LongTensor([15])]
        >>> subgraph_sampler = gb.NeighborSampler(
        ...     neg_sampler, graph, fanouts)
        >>> for data in subgraph_sampler:
        ...     print(data.compacted_node_pair)
        ...     print(len(data.sampled_subgraphs))
        (tensor([0, 0, 0]), tensor([1, 0, 2]))
        3
        (tensor([0, 0, 0]), tensor([1, 1, 1]))
        3
        """
        super().__init__(datapipe)
        self.fanouts = fanouts
        self.replace = replace
        self.prob_name = prob_name
        self.graph = graph

    def _sample_sub_graphs(self, seeds):
        subgraphs = []
        num_layers = len(self.fanouts)
        # Sample one hop per fanout entry; the unique nodes discovered in a
        # hop become the seeds of the next hop.
        for hop in range(num_layers):
            subgraph = self.graph.sample_neighbors(
                seeds,
                self.fanouts[hop],
                self.replace,
                self.prob_name,
            )
            # Remember the seed nodes used for this hop before expanding them.
            reverse_row_node_ids = seeds
            seeds, compacted_node_pairs = unique_and_compact_node_pairs(
                subgraph.node_pairs, seeds
            )
            subgraph = SampledSubgraphImpl(
                node_pairs=compacted_node_pairs,
                reverse_column_node_ids=seeds,
                reverse_row_node_ids=reverse_row_node_ids,
            )
            # Prepend so that subgraphs are ordered from the outermost layer
            # to the innermost one.
            subgraphs.insert(0, subgraph)
        return seeds, subgraphs
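
For node classification, the same sampler is driven by a NodeClassificationBlock instead of a LinkPredictionBlock. Below is a minimal usage sketch mirroring the updated tests further down in this commit; it assumes the tests' gb_test_utils.rand_csc_graph helper and is illustrative rather than part of the PR.

import torch
import dgl.graphbolt as gb
import gb_test_utils  # test helper, assumed available as in the tests below
from torchdata.datapipes.iter import Mapper


def to_node_block(data):
    # Wrap a batch of seed nodes into the block type the sampler dispatches on.
    return gb.NodeClassificationBlock(seed_node=data)


graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(torch.arange(10))
minibatch_sampler = gb.MinibatchSampler(itemset, batch_size=2)
block_converter = Mapper(minibatch_sampler, to_node_block)
fanouts = [torch.LongTensor([2]) for _ in range(2)]
sampler = gb.NeighborSampler(block_converter, graph, fanouts)
for data in sampler:
    # input_nodes are the unique nodes needed for feature fetching;
    # sampled_subgraphs holds one SampledSubgraphImpl per layer.
    assert len(data.sampled_subgraphs) == len(fanouts)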
"""Subgraph samplers""" """Subgraph samplers"""
from torchdata.datapipes.iter import Mapper from collections import defaultdict
from typing import Dict
from torchdata.datapipes.iter import Mapper
class SubgraphSampler(Mapper): from .link_prediction_block import LinkPredictionBlock
"""A subgraph sampler. from .node_classification_block import NodeClassificationBlock
from .utils import unique_and_compact
It is an iterator equivalent to the following:
.. code:: python class SubgraphSampler(Mapper):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph."""
for data in datapipe: def __init__(
yield sampler_func(data) self,
datapipe,
):
"""
Initlization for a subgraph sampler.
Parameters Parameters
---------- ----------
datapipe : DataPipe datapipe : DataPipe
The datapipe. The datapipe.
fn : callable
The subgraph sampling function.
""" """
super().__init__(datapipe, self._sample)
def _sample(self, data):
if isinstance(data, LinkPredictionBlock):
(
seeds,
data.compacted_node_pair,
data.compacted_negative_head,
data.compacted_negative_tail,
) = self._link_prediction_preprocess(data)
elif isinstance(data, NodeClassificationBlock):
seeds = data.seed_node
else:
raise TypeError(f"Unsupported type of data {data}.")
data.input_nodes, data.sampled_subgraphs = self._sample_sub_graphs(
seeds
)
return data
def _link_prediction_preprocess(self, data):
node_pair = data.node_pair
neg_src, neg_dst = data.negative_head, data.negative_tail
has_neg_src = neg_src is not None
has_neg_dst = neg_dst is not None
is_heterogeneous = isinstance(node_pair, Dict)
if is_heterogeneous:
# Collect nodes from all types of input.
nodes = defaultdict(list)
for (src_type, _, dst_type), (src, dst) in node_pair.items():
nodes[src_type].append(src)
nodes[dst_type].append(dst)
if has_neg_src:
for (src_type, _, _), src in neg_src.items():
nodes[src_type].append(src.view(-1))
if has_neg_dst:
for (_, _, dst_type), dst in neg_dst.items():
nodes[dst_type].append(dst.view(-1))
# Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes)
(
compacted_node_pair,
compacted_negative_head,
compacted_negative_tail,
) = ({}, {}, {})
# Map back in same order as collect.
for etype, _ in node_pair.items():
src_type, _, dst_type = etype
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_node_pair[etype] = (src, dst)
if has_neg_src:
for etype, _ in neg_src.items():
compacted_negative_head[etype] = compacted[etype[0]].pop(0)
if has_neg_dst:
for etype, _ in neg_dst.items():
compacted_negative_tail[etype] = compacted[etype[2]].pop(0)
else:
# Collect nodes from all types of input.
nodes = list(node_pair)
if has_neg_src:
nodes.append(neg_src.view(-1))
if has_neg_dst:
nodes.append(neg_dst.view(-1))
# Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes)
# Map back in same order as collect.
compacted_node_pair = tuple(compacted[:2])
compacted = compacted[2:]
if has_neg_src:
compacted_negative_head = compacted.pop(0)
if has_neg_dst:
compacted_negative_tail = compacted.pop(0)
return (
seeds,
compacted_node_pair,
compacted_negative_head if has_neg_src else None,
compacted_negative_tail if has_neg_dst else None,
)
def _sample_sub_graphs(self, seeds):
raise NotImplementedError
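
With this refactor, user-defined sampling is expressed by subclassing SubgraphSampler and overriding _sample_sub_graphs (as NeighborSampler does above), rather than by passing a sampler function. A minimal sketch of such a subclass follows; "IdentitySampler" is a hypothetical name and not part of this commit.

class IdentitySampler(SubgraphSampler):
    """Hypothetical example: echoes the seeds and samples no subgraphs."""

    def _sample_sub_graphs(self, seeds):
        # The base class assigns the returned pair to data.input_nodes and
        # data.sampled_subgraphs after its pre-processing step.
        return seeds, []

Any upstream datapipe that yields NodeClassificationBlock or LinkPredictionBlock instances can feed such a subclass, since the dispatch in _sample handles both.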
+import os
+import unittest
 from functools import partial
 
 import backend as F
@@ -5,12 +7,17 @@ import dgl
 import dgl.graphbolt
 import gb_test_utils
 import torch
+from torchdata.datapipes.iter import Mapper
 
 
-def sampler_func(graph, data):
-    seeds = data
-    sampler = dgl.dataloading.NeighborSampler([2, 2])
-    return sampler.sample(graph, seeds)
+def to_node_block(data):
+    block = dgl.graphbolt.NodeClassificationBlock(seed_node=data)
+    return block
+
+
+def to_tuple(data):
+    output_nodes = data.sampled_subgraphs[-1].reverse_column_node_ids
+    return data.input_nodes, output_nodes, data.sampled_subgraphs
 
 
 def fetch_func(features, labels, data):
@@ -20,23 +27,26 @@ def fetch_func(features, labels, data):
     return input_features, output_labels, adjs
 
 
+@unittest.skipIf(os.name == "nt", reason="Do not support Windows yet")
+# TODO (peizhou): Will enable Windows test once CSCSamplingGraph is pickleable.
 def test_DataLoader():
     N = 40
     B = 4
     itemset = dgl.graphbolt.ItemSet(torch.arange(N))
-    # TODO(BarclayII): temporarily using DGLGraph. Should test using
-    # GraphBolt's storage as well once issue #5953 is resolved.
-    graph = dgl.add_reverse_edges(dgl.rand_graph(200, 6000))
+    graph = gb_test_utils.rand_csc_graph(200, 0.15)
     features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
     labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
     minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-    subgraph_sampler = dgl.graphbolt.SubgraphSampler(
-        minibatch_sampler,
-        partial(sampler_func, graph),
+    block_converter = Mapper(minibatch_sampler, to_node_block)
+    subgraph_sampler = dgl.graphbolt.NeighborSampler(
+        block_converter,
+        graph,
+        fanouts=[torch.LongTensor([2]) for _ in range(2)],
     )
+    tuple_converter = Mapper(subgraph_sampler, to_tuple)
     feature_fetcher = dgl.graphbolt.FeatureFetcher(
-        subgraph_sampler,
+        tuple_converter,
         partial(fetch_func, features, labels),
     )
     device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
...
@@ -3,6 +3,17 @@ import dgl
 import dgl.graphbolt
 import gb_test_utils
 import torch
+from torchdata.datapipes.iter import Mapper
+
+
+def to_node_block(data):
+    block = dgl.graphbolt.NodeClassificationBlock(seed_node=data)
+    return block
+
+
+def to_tuple(data):
+    output_nodes = data.sampled_subgraphs[-1].reverse_column_node_ids
+    return data.input_nodes, output_nodes, data.sampled_subgraphs
 
 
 def test_DataLoader():
@@ -13,19 +24,6 @@ def test_DataLoader():
     features = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))
     labels = dgl.graphbolt.TorchBasedFeature(torch.randint(0, 10, (200,)))
 
-    def sampler_func(data):
-        adjs = []
-        seeds = data
-
-        for hop in range(2):
-            sg = graph.sample_neighbors(seeds, torch.LongTensor([2]))
-            seeds = sg.node_pairs[0]
-            adjs.insert(0, sg)
-
-        input_nodes = seeds
-        output_nodes = data
-        return input_nodes, output_nodes, adjs
-
     def fetch_func(data):
         input_nodes, output_nodes, adjs = data
         input_features = features.read(input_nodes)
@@ -33,11 +31,14 @@ def test_DataLoader():
         return input_features, output_labels, adjs
 
     minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
-    subgraph_sampler = dgl.graphbolt.SubgraphSampler(
-        minibatch_sampler,
-        sampler_func,
+    block_converter = Mapper(minibatch_sampler, to_node_block)
+    subgraph_sampler = dgl.graphbolt.NeighborSampler(
+        block_converter,
+        graph,
+        fanouts=[torch.LongTensor([2]) for _ in range(2)],
     )
-    feature_fetcher = dgl.graphbolt.FeatureFetcher(subgraph_sampler, fetch_func)
+    tuple_converter = Mapper(subgraph_sampler, to_tuple)
+    feature_fetcher = dgl.graphbolt.FeatureFetcher(tuple_converter, fetch_func)
     device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
     dataloader = dgl.graphbolt.SingleProcessDataLoader(device_transferrer)
...
-import dgl
-import dgl.graphbolt
+import dgl.graphbolt as gb
 import gb_test_utils
 import pytest
 import torch
 import torchdata.datapipes as dp
+from torchdata.datapipes.iter import Mapper
 
 
-def get_graphbolt_sampler_func():
-    graph = gb_test_utils.rand_csc_graph(20, 0.15)
-
-    def sampler_func(data):
-        adjs = []
-        seeds = data
-
-        for hop in range(2):
-            sg = graph.sample_neighbors(seeds, torch.LongTensor([2]))
-            seeds = sg.node_pairs[0]
-            adjs.insert(0, sg)
-
-        return seeds, data, adjs
-
-    return sampler_func
-
-
-def get_dgl_sampler_func():
-    graph = dgl.add_reverse_edges(dgl.rand_graph(20, 60))
-    sampler = dgl.dataloading.NeighborSampler([2, 2])
-
-    def sampler_func(data):
-        return sampler.sample(graph, data)
-
-    return sampler_func
-
-
-def get_graphbolt_minibatch_dp():
-    itemset = dgl.graphbolt.ItemSet(torch.arange(10))
-    return dgl.graphbolt.MinibatchSampler(itemset, batch_size=2)
-
-
-def get_torchdata_minibatch_dp():
-    minibatch_dp = dp.map.SequenceWrapper(torch.arange(10)).batch(2)
-    minibatch_dp = minibatch_dp.to_iter_datapipe().collate()
-    return minibatch_dp
-
-
-@pytest.mark.parametrize(
-    "sampler_func", [get_graphbolt_sampler_func(), get_dgl_sampler_func()]
-)
-@pytest.mark.parametrize(
-    "minibatch_dp", [get_graphbolt_minibatch_dp(), get_torchdata_minibatch_dp()]
-)
-def test_SubgraphSampler(minibatch_dp, sampler_func):
-    sampler_dp = dgl.graphbolt.SubgraphSampler(minibatch_dp, sampler_func)
-    assert len(list(sampler_dp)) == 5
+def to_node_block(data):
+    block = gb.NodeClassificationBlock(seed_node=data)
+    return block
+
+
+def test_SubgraphSampler_Node():
+    graph = gb_test_utils.rand_csc_graph(20, 0.15)
+    itemset = gb.ItemSet(torch.arange(10))
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+    data_block_converter = Mapper(minibatch_dp, to_node_block)
+    sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
+    assert len(list(sampler_dp)) == 5
+
+
+def to_link_block(data):
+    block = gb.LinkPredictionBlock(node_pair=data)
+    return block
+
+
+def test_SubgraphSampler_Link():
+    graph = gb_test_utils.rand_csc_graph(20, 0.15)
+    itemset = gb.ItemSet(
+        (
+            torch.arange(0, 10),
+            torch.arange(10, 20),
+        )
+    )
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    neighbor_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
+    assert len(list(neighbor_dp)) == 5
+
+
+@pytest.mark.parametrize(
+    "format",
+    [
+        gb.LinkPredictionEdgeFormat.INDEPENDENT,
+        gb.LinkPredictionEdgeFormat.CONDITIONED,
+        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
+        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
+    ],
+)
+def test_SubgraphSampler_Link_With_Negative(format):
+    graph = gb_test_utils.rand_csc_graph(20, 0.15)
+    itemset = gb.ItemSet(
+        (
+            torch.arange(0, 10),
+            torch.arange(10, 20),
+        )
+    )
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    negative_dp = gb.UniformNegativeSampler(
+        data_block_converter, 1, format, graph
+    )
+    neighbor_dp = gb.NeighborSampler(negative_dp, graph, fanouts)
+    assert len(list(neighbor_dp)) == 5
+
+
+def get_hetero_graph():
+    # COO graph:
+    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
+    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
+    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
+    # num_nodes = 5, num_n1 = 2, num_n2 = 3
+    ntypes = {"n1": 0, "n2": 1}
+    etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
+    metadata = gb.GraphMetadata(ntypes, etypes)
+    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
+    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
+    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
+    node_type_offset = torch.LongTensor([0, 2, 5])
+    return gb.from_csc(
+        indptr,
+        indices,
+        node_type_offset=node_type_offset,
+        type_per_edge=type_per_edge,
+        metadata=metadata,
+    )
+
+
+def test_SubgraphSampler_Link_Hetero():
+    graph = get_hetero_graph()
+    itemset = gb.ItemSetDict(
+        {
+            ("n1", "e1", "n2"): gb.ItemSet(
+                (
+                    torch.LongTensor([0, 0, 1, 1]),
+                    torch.LongTensor([0, 2, 0, 1]),
+                )
+            ),
+            ("n2", "e2", "n1"): gb.ItemSet(
+                (
+                    torch.LongTensor([0, 0, 1, 1, 2, 2]),
+                    torch.LongTensor([0, 1, 1, 0, 0, 1]),
+                )
+            ),
+        }
+    )
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    neighbor_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
+    assert len(list(neighbor_dp)) == 5
+
+
+@pytest.mark.parametrize(
+    "format",
+    [
+        gb.LinkPredictionEdgeFormat.INDEPENDENT,
+        gb.LinkPredictionEdgeFormat.CONDITIONED,
+        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
+        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
+    ],
+)
+def test_SubgraphSampler_Link_Hetero_With_Negative(format):
+    graph = get_hetero_graph()
+    itemset = gb.ItemSetDict(
+        {
+            ("n1", "e1", "n2"): gb.ItemSet(
+                (
+                    torch.LongTensor([0, 0, 1, 1]),
+                    torch.LongTensor([0, 2, 0, 1]),
+                )
+            ),
+            ("n2", "e2", "n1"): gb.ItemSet(
+                (
+                    torch.LongTensor([0, 0, 1, 1, 2, 2]),
+                    torch.LongTensor([0, 1, 1, 0, 0, 1]),
+                )
+            ),
+        }
+    )
+    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
+    num_layer = 2
+    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
+    data_block_converter = Mapper(minibatch_dp, to_link_block)
+    negative_dp = gb.UniformNegativeSampler(
+        data_block_converter, 1, format, graph
+    )
+    neighbor_dp = gb.NeighborSampler(negative_dp, graph, fanouts)
+    assert len(list(neighbor_dp)) == 5
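
As an illustrative sanity check on the heterogeneous path (not part of this commit), a hypothetical helper such as the one below could be called inside the two hetero tests in place of the final length assertion; it relies only on attributes set by SubgraphSampler and NeighborSampler.

def _check_hetero_blocks(datapipe, num_layer, canonical_etypes):
    """Hypothetical helper: verify per-edge-type outputs of the sampler."""
    for data in datapipe:
        # In the heterogeneous branch of _link_prediction_preprocess,
        # compacted_node_pair is keyed by canonical edge type.
        assert set(data.compacted_node_pair.keys()) <= set(canonical_etypes)
        # NeighborSampler produces one SampledSubgraphImpl per fanout entry.
        assert len(data.sampled_subgraphs) == num_layer

For example, _check_hetero_blocks(neighbor_dp, num_layer, [("n1", "e1", "n2"), ("n2", "e2", "n1")]).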