Unverified Commit ec495239 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] Enable invoking gb samplers in functional form (#6297)

parent 79a95477
"""Base types and utilities for Graph Bolt."""
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from ..utils import recursive_apply
......@@ -53,6 +54,7 @@ def _to(x, device):
return x.to(device) if hasattr(x, "to") else x
@functional_datapipe("copy_to")
class CopyTo(IterDataPipe):
"""DataPipe that transfers each element yielded from the previous DataPipe
to the given device.
......
......@@ -2,9 +2,12 @@
from typing import Dict
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
@functional_datapipe("fetch_feature")
class FeatureFetcher(Mapper):
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
......
"""Neighbor subgraph samplers for GraphBolt."""
from torch.utils.data import functional_datapipe
from ..subgraph_sampler import SubgraphSampler
from ..utils import unique_and_compact_node_pairs
from .sampled_subgraph_impl import SampledSubgraphImpl
@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
"""
Neighbor sampler is responsible for sampling a subgraph from given data. It
......@@ -106,6 +109,7 @@ class NeighborSampler(SubgraphSampler):
return seeds, subgraphs
@functional_datapipe("sample_layer_neighbor")
class LayerNeighborSampler(NeighborSampler):
"""
Layer-Neighbor sampler is responsible for sampling a subgraph from given
......
"""Uniform negative sampler for GraphBolt."""
from torch.utils.data import functional_datapipe
from ..negative_sampler import NegativeSampler
@functional_datapipe("sample_uniform_negative")
class UniformNegativeSampler(NegativeSampler):
"""
Negative samplers randomly select negative destination nodes for each
......
......@@ -4,7 +4,7 @@ from collections.abc import Mapping
from functools import partial
from typing import Callable, Iterator, Optional
from torch.utils.data import default_collate
from torch.utils.data import default_collate, functional_datapipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..base import dgl_warning
......@@ -78,6 +78,7 @@ def minibatcher_default(batch, names):
return minibatch
@functional_datapipe("sample_item")
class ItemSampler(IterDataPipe):
"""Item Sampler.
......
......@@ -3,11 +3,13 @@
from _collections_abc import Mapping
import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from .data_format import LinkPredictionEdgeFormat
@functional_datapipe("sample_negative")
class NegativeSampler(Mapper):
"""
A negative sampler used to generate negative samples and return
......
......@@ -3,12 +3,14 @@
from collections import defaultdict
from typing import Dict
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from .base import etype_str_to_tuple
from .utils import unique_and_compact
@functional_datapipe("sample_subgraph")
class SubgraphSampler(Mapper):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph."""
......
......@@ -5,6 +5,75 @@ import torch
from torchdata.datapipes.iter import Mapper
def test_NegativeSampler_invoke():
    """Build the base NegativeSampler via constructor and functional form.

    In both cases the test expects iteration to raise
    ``NotImplementedError``.
    """
    # Item sampler over 30 (src, dst) pairs, batched by 10.
    seed_count = 30
    pairs = torch.arange(0, 2 * seed_count).reshape(-1, 2)
    item_sampler = gb.ItemSampler(
        gb.ItemSet(pairs, names="node_pairs"), batch_size=10
    )
    ratio = 2
    edge_format = gb.LinkPredictionEdgeFormat.INDEPENDENT

    # Class-constructor form.
    via_constructor = gb.NegativeSampler(item_sampler, ratio, edge_format)
    with pytest.raises(NotImplementedError):
        next(iter(via_constructor))

    # Functional form registered on the datapipe.
    via_functional = item_sampler.sample_negative(ratio, edge_format)
    with pytest.raises(NotImplementedError):
        next(iter(via_functional))
def test_UniformNegativeSampler_invoke():
    """Build UniformNegativeSampler via constructor and functional form.

    Every minibatch yielded by either pipeline must hold
    ``batch_size * (negative_ratio + 1)`` sources, destinations and labels,
    with the first ``batch_size`` labels equal to 1 and the rest equal to 0.
    """
    # Random graph plus an item sampler over 30 (src, dst) pairs.
    graph = gb_test_utils.rand_csc_graph(100, 0.05)
    num_seeds = 30
    pairs = torch.arange(0, 2 * num_seeds).reshape(-1, 2)
    batch_size = 10
    item_sampler = gb.ItemSampler(
        gb.ItemSet(pairs, names="node_pairs"), batch_size=batch_size
    )
    negative_ratio = 2
    edge_format = gb.LinkPredictionEdgeFormat.INDEPENDENT

    def _verify(negative_sampler):
        # Size of each minibatch once the negatives have been appended.
        expected_len = batch_size * (negative_ratio + 1)
        for minibatch in negative_sampler:
            src, dst = minibatch.node_pairs
            labels = minibatch.labels
            # Assertions on sizes and on the positive/negative label split.
            assert len(src) == expected_len
            assert len(dst) == expected_len
            assert len(labels) == expected_len
            assert torch.all(torch.eq(labels[:batch_size], 1))
            assert torch.all(torch.eq(labels[batch_size:], 0))

    # Class-constructor form.
    _verify(
        gb.UniformNegativeSampler(
            item_sampler, negative_ratio, edge_format, graph
        )
    )

    # Functional form registered on the datapipe.
    _verify(
        item_sampler.sample_uniform_negative(
            negative_ratio, edge_format, graph
        )
    )
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_NegativeSampler_Independent_Format(negative_ratio):
# Construct CSCSamplingGraph.
......
......@@ -11,8 +11,14 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
    """Check that CopyTo moves every yielded element to the CUDA device.

    The constructor form and the functional ``copy_to`` form are each
    applied to the same upstream item sampler, so each form is exercised
    on its own instead of the functional form being stacked on top of an
    already-copied pipeline.
    """
    item_sampler = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)

    # Invoke CopyTo via class constructor.
    dp = gb.CopyTo(item_sampler, "cuda")
    for data in dp:
        assert data.device.type == "cuda"

    # Invoke CopyTo via functional form, starting again from the raw sampler.
    dp = item_sampler.copy_to("cuda")
    for data in dp:
        assert data.device.type == "cuda"
......
......@@ -4,6 +4,35 @@ import torch
from torchdata.datapipes.iter import Mapper
def test_FeatureFetcher_invoke():
    """Build FeatureFetcher via constructor and via functional form.

    Both pipelines start from the same item sampler (10 seed nodes,
    batch size 2) and are each expected to yield 5 minibatches. The
    functional chain is deliberately rooted at the item sampler rather
    than at the constructor-built pipeline, so that it is tested on its
    own and not as a second sampling/fetching stage.
    """
    # Prepare graph and required datapipes.
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    a = torch.randint(0, 10, (graph.num_nodes,))
    b = torch.randint(0, 10, (graph.num_edges,))

    features = {}
    keys = [("node", None, "a"), ("edge", None, "b")]
    features[keys[0]] = gb.TorchBasedFeature(a)
    features[keys[1]] = gb.TorchBasedFeature(b)
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke FeatureFetcher via class constructor.
    datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)
    datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
    assert len(list(datapipe)) == 5

    # Invoke FeatureFetcher via functional form, from the raw item sampler.
    datapipe = item_sampler.sample_neighbor(graph, fanouts).fetch_feature(
        feature_store, ["a"], ["b"]
    )
    assert len(list(datapipe)) == 5
def test_FeatureFetcher_homo():
graph = gb_test_utils.rand_csc_graph(20, 0.15)
a = torch.randint(0, 10, (graph.num_nodes,))
......
......@@ -6,6 +6,42 @@ import torchdata.datapipes as dp
from torchdata.datapipes.iter import Mapper
def test_SubgraphSampler_invoke():
    """Build the base SubgraphSampler via constructor and functional form.

    In both cases the test expects iteration to raise
    ``NotImplementedError``. Each form wraps the item sampler directly;
    stacking the functional form on the constructor-built sampler would
    let the inner sampler raise first and leave the functional form
    unverified.
    """
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)

    # Invoke via class constructor.
    datapipe = gb.SubgraphSampler(item_sampler)
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))

    # Invoke via functional form.
    datapipe = item_sampler.sample_subgraph()
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor):
    """Build a neighbor sampler via constructor and via functional form.

    ``labor`` selects ``LayerNeighborSampler`` / ``sample_layer_neighbor``
    when true and ``NeighborSampler`` / ``sample_neighbor`` otherwise.
    Both pipelines are rooted at the same item sampler (10 seed nodes,
    batch size 2) and are each expected to yield 5 minibatches; the
    functional form is not stacked on the constructor-built sampler.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke via class constructor.
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts)
    assert len(list(datapipe)) == 5

    # Invoke via functional form, starting again from the raw item sampler.
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment