"git@developer.sourcefind.cn:OpenDAS/nerfacc.git" did not exist on "15e4430fae9ccabd94ce36cd12a344fca20068a3"
Unverified Commit 337b5ea7 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable fanouts to be a list of int (#6309)

parent 89bed21a
"""Neighbor subgraph samplers for GraphBolt.""" """Neighbor subgraph samplers for GraphBolt."""
import torch
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from ..subgraph_sampler import SubgraphSampler from ..subgraph_sampler import SubgraphSampler
...@@ -37,7 +38,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -37,7 +38,7 @@ class NeighborSampler(SubgraphSampler):
The datapipe. The datapipe.
graph : CSCSamplingGraph graph : CSCSamplingGraph
The graph on which to perform subgraph sampling. The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor] fanouts: list[torch.Tensor] or list[int]
The number of edges to be sampled for each node with or without The number of edges to be sampled for each node with or without
considering edge types. The length of this parameter implicitly considering edge types. The length of this parameter implicitly
signifies the layer of sampling being conducted. signifies the layer of sampling being conducted.
...@@ -81,7 +82,12 @@ class NeighborSampler(SubgraphSampler): ...@@ -81,7 +82,12 @@ class NeighborSampler(SubgraphSampler):
3 3
""" """
super().__init__(datapipe) super().__init__(datapipe)
self.fanouts = fanouts # Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.append(fanout)
self.replace = replace self.replace = replace
self.prob_name = prob_name self.prob_name = prob_name
self.sampler = graph.sample_neighbors self.sampler = graph.sample_neighbors
......
...@@ -41,6 +41,30 @@ def test_NeighborSampler_invoke(labor): ...@@ -41,6 +41,30 @@ def test_NeighborSampler_invoke(labor):
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_fanouts(labor):
    """Check that the neighbor samplers accept both fanout representations.

    The same pipeline is built twice, once with ``fanouts`` given as a list
    of tensors and once as a list of plain integers. Each run must produce
    5 items (10 seed nodes, batch size 2).

    Parameters
    ----------
    labor : bool
        When True the pipeline is built with ``sample_layer_neighbor``,
        otherwise with ``sample_neighbor``.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2

    fanout_variants = (
        # `fanouts` is a list of tensors.
        [torch.LongTensor([2]) for _ in range(num_layer)],
        # `fanouts` is a list of integers.
        [2] * num_layer,
    )
    for fanouts in fanout_variants:
        build_sampler = (
            item_sampler.sample_layer_neighbor
            if labor
            else item_sampler.sample_neighbor
        )
        datapipe = build_sampler(graph, fanouts)
        assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor): def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment