Unverified Commit 3fb81fca authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[Graphbolt] LayerNeighborSampler and tests (#6214)

parent 44a9faad
"""Neighbor subgraph sampler for GraphBolt."""
"""Neighbor subgraph samplers for GraphBolt."""
from ..subgraph_sampler import SubgraphSampler
from ..utils import unique_and_compact_node_pairs
......@@ -86,13 +86,13 @@ class NeighborSampler(SubgraphSampler):
self.fanouts = fanouts
self.replace = replace
self.prob_name = prob_name
self.graph = graph
self.sampler = graph.sample_neighbors
def _sample_subgraphs(self, seeds):
subgraphs = []
num_layers = len(self.fanouts)
for hop in range(num_layers):
subgraph = self.graph.sample_neighbors(
subgraph = self.sampler(
seeds,
self.fanouts[hop],
self.replace,
......@@ -109,3 +109,89 @@ class NeighborSampler(SubgraphSampler):
)
subgraphs.insert(0, subgraph)
return seeds, subgraphs
class LayerNeighborSampler(NeighborSampler):
    """Layer-neighbor subgraph sampler for GraphBolt.

    Layer-Neighbor sampler is responsible for sampling a subgraph from given
    data. It returns an induced subgraph along with compacted information. In
    the context of a node classification task, the neighbor sampler directly
    utilizes the nodes provided as seed nodes. However, in scenarios involving
    link prediction, the process needs another pre-processing operation. That
    is, gathering unique nodes from the given node pairs, encompassing both
    positive and negative node pairs, and employing these nodes as the seed
    nodes for subsequent steps.

    Implements the approach described in https://arxiv.org/abs/2210.13339,
    Appendix A.3. Similar to dgl.dataloading.LaborSampler but this uses
    sequential poisson sampling instead of poisson sampling to keep the count
    of sampled edges per vertex deterministic.
    """

    def __init__(
        self,
        datapipe,
        graph,
        fanouts,
        replace=False,
        prob_name=None,
    ):
        """
        Initialization for a layer neighbor subgraph sampler.

        Parameters
        ----------
        datapipe : DataPipe
            The datapipe.
        graph : CSCSamplingGraph
            The graph on which to perform subgraph sampling.
        fanouts: list[torch.Tensor]
            The number of edges to be sampled for each node with or without
            considering edge types. The length of this parameter implicitly
            signifies the layer of sampling being conducted.
        replace: bool
            Boolean indicating whether the sample is performed with or
            without replacement. If True, a value can be selected multiple
            times. Otherwise, each value can be selected only once.
            Defaults to False.
        prob_name: str, optional
            The name of an edge attribute used as the weights of sampling for
            each node. This attribute tensor should contain (unnormalized)
            probabilities corresponding to each neighboring edge of a node.
            It must be a 1D floating-point or boolean tensor, with the number
            of elements equalling the total number of edges.

        Examples
        --------
        >>> import torch
        >>> import dgl.graphbolt as gb
        >>> from torchdata.datapipes.iter import Mapper
        >>> def to_link_block(data):
        ...     block = gb.LinkPredictionBlock(node_pair=data)
        ...     return block
        ...
        >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7, 8])
        >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
        >>> graph = gb.from_csc(indptr, indices)
        >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
        >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
        >>> item_set = gb.ItemSet(node_pairs)
        >>> minibatch_sampler = gb.MinibatchSampler(
        ...     item_set, batch_size=1,
        ... )
        >>> data_block_converter = Mapper(minibatch_sampler, to_link_block)
        >>> neg_sampler = gb.UniformNegativeSampler(
        ...     data_block_converter, 2, data_format, graph)
        >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
        ...     torch.LongTensor([15])]
        >>> subgraph_sampler = gb.LayerNeighborSampler(
        ...     neg_sampler, graph, fanouts)
        >>> for data in subgraph_sampler:
        ...     print(data.compacted_node_pair)
        ...     print(len(data.sampled_subgraphs))
        (tensor([0, 0, 0]), tensor([1, 0, 2]))
        3
        (tensor([0, 0, 0]), tensor([1, 1, 1]))
        3
        """
        super().__init__(datapipe, graph, fanouts, replace, prob_name)
        # The base constructor binds ``self.sampler`` to
        # ``graph.sample_neighbors``; rebind it afterwards so the inherited
        # per-hop sampling loop calls the layer-neighbor routine instead.
        self.sampler = graph.sample_layer_neighbors
......@@ -11,14 +11,16 @@ def to_node_block(data):
return block
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
    """Check that node-seeded subgraph sampling yields one item per minibatch.

    Ten seed nodes are batched two at a time, so the sampling datapipe is
    expected to produce exactly five items.

    Parameters
    ----------
    labor : bool
        When True the test exercises ``gb.LayerNeighborSampler``; otherwise
        it exercises ``gb.NeighborSampler``.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10))
    minibatch_dp = gb.MinibatchSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    data_block_converter = Mapper(minibatch_dp, to_node_block)
    # Both samplers share the same constructor signature, so the class can be
    # selected up front and instantiated exactly once.
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(data_block_converter, graph, fanouts)
    assert len(list(sampler_dp)) == 5
......@@ -27,7 +29,8 @@ def to_link_block(data):
return block
def test_SubgraphSampler_Link():
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(
(
......@@ -39,7 +42,8 @@ def test_SubgraphSampler_Link():
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(minibatch_dp, to_link_block)
neighbor_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(data_block_converter, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -52,7 +56,8 @@ def test_SubgraphSampler_Link():
gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
],
)
def test_SubgraphSampler_Link_With_Negative(format):
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(format, labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(
(
......@@ -67,7 +72,8 @@ def test_SubgraphSampler_Link_With_Negative(format):
negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph
)
neighbor_dp = gb.NeighborSampler(negative_dp, graph, fanouts)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -93,7 +99,8 @@ def get_hetero_graph():
)
def test_SubgraphSampler_Link_Hetero():
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor):
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{
......@@ -116,7 +123,8 @@ def test_SubgraphSampler_Link_Hetero():
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(minibatch_dp, to_link_block)
neighbor_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(data_block_converter, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -129,7 +137,8 @@ def test_SubgraphSampler_Link_Hetero():
gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
],
)
def test_SubgraphSampler_Link_Hetero_With_Negative(format):
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{
......@@ -155,5 +164,6 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format):
negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph
)
neighbor_dp = gb.NeighborSampler(negative_dp, graph, fanouts)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment