Unverified Commit 337b5ea7 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable fanouts to be a list of int (#6309)

parent 89bed21a
"""Neighbor subgraph samplers for GraphBolt."""
import torch
from torch.utils.data import functional_datapipe
from ..subgraph_sampler import SubgraphSampler
......@@ -37,7 +38,7 @@ class NeighborSampler(SubgraphSampler):
The datapipe.
graph : CSCSamplingGraph
The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor]
fanouts: list[torch.Tensor] or list[int]
The number of edges to be sampled for each node with or without
considering edge types. The length of this parameter implicitly
signifies the layer of sampling being conducted.
......@@ -81,7 +82,12 @@ class NeighborSampler(SubgraphSampler):
3
"""
super().__init__(datapipe)
self.fanouts = fanouts
# Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.append(fanout)
self.replace = replace
self.prob_name = prob_name
self.sampler = graph.sample_neighbors
......
......@@ -41,6 +41,30 @@ def test_NeighborSampler_invoke(labor):
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_fanouts(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
# `fanouts` is a list of tensors.
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
if labor:
datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
else:
datapipe = item_sampler.sample_neighbor(graph, fanouts)
assert len(list(datapipe)) == 5
# `fanouts` is a list of integers.
fanouts = [2 for _ in range(num_layer)]
if labor:
datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
else:
datapipe = item_sampler.sample_neighbor(graph, fanouts)
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment