Unverified Commit ec495239 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable to invoke gb samplers in functional form (#6297)

parent 79a95477
"""Base types and utilities for Graph Bolt.""" """Base types and utilities for Graph Bolt."""
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
from ..utils import recursive_apply from ..utils import recursive_apply
...@@ -53,6 +54,7 @@ def _to(x, device): ...@@ -53,6 +54,7 @@ def _to(x, device):
return x.to(device) if hasattr(x, "to") else x return x.to(device) if hasattr(x, "to") else x
@functional_datapipe("copy_to")
class CopyTo(IterDataPipe): class CopyTo(IterDataPipe):
"""DataPipe that transfers each element yielded from the previous DataPipe """DataPipe that transfers each element yielded from the previous DataPipe
to the given device. to the given device.
......
...@@ -2,9 +2,12 @@ ...@@ -2,9 +2,12 @@
from typing import Dict from typing import Dict
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
@functional_datapipe("fetch_feature")
class FeatureFetcher(Mapper): class FeatureFetcher(Mapper):
"""A feature fetcher used to fetch features for node/edge in graphbolt.""" """A feature fetcher used to fetch features for node/edge in graphbolt."""
......
"""Neighbor subgraph samplers for GraphBolt.""" """Neighbor subgraph samplers for GraphBolt."""
from torch.utils.data import functional_datapipe
from ..subgraph_sampler import SubgraphSampler from ..subgraph_sampler import SubgraphSampler
from ..utils import unique_and_compact_node_pairs from ..utils import unique_and_compact_node_pairs
from .sampled_subgraph_impl import SampledSubgraphImpl from .sampled_subgraph_impl import SampledSubgraphImpl
@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler): class NeighborSampler(SubgraphSampler):
""" """
Neighbor sampler is responsible for sampling a subgraph from given data. It Neighbor sampler is responsible for sampling a subgraph from given data. It
...@@ -106,6 +109,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -106,6 +109,7 @@ class NeighborSampler(SubgraphSampler):
return seeds, subgraphs return seeds, subgraphs
@functional_datapipe("sample_layer_neighbor")
class LayerNeighborSampler(NeighborSampler): class LayerNeighborSampler(NeighborSampler):
""" """
Layer-Neighbor sampler is responsible for sampling a subgraph from given Layer-Neighbor sampler is responsible for sampling a subgraph from given
......
"""Uniform negative sampler for GraphBolt.""" """Uniform negative sampler for GraphBolt."""
from torch.utils.data import functional_datapipe
from ..negative_sampler import NegativeSampler from ..negative_sampler import NegativeSampler
@functional_datapipe("sample_uniform_negative")
class UniformNegativeSampler(NegativeSampler): class UniformNegativeSampler(NegativeSampler):
""" """
Negative samplers randomly select negative destination nodes for each Negative samplers randomly select negative destination nodes for each
......
...@@ -4,7 +4,7 @@ from collections.abc import Mapping ...@@ -4,7 +4,7 @@ from collections.abc import Mapping
from functools import partial from functools import partial
from typing import Callable, Iterator, Optional from typing import Callable, Iterator, Optional
from torch.utils.data import default_collate from torch.utils.data import default_collate, functional_datapipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..base import dgl_warning from ..base import dgl_warning
...@@ -78,6 +78,7 @@ def minibatcher_default(batch, names): ...@@ -78,6 +78,7 @@ def minibatcher_default(batch, names):
return minibatch return minibatch
@functional_datapipe("sample_item")
class ItemSampler(IterDataPipe): class ItemSampler(IterDataPipe):
"""Item Sampler. """Item Sampler.
......
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
from _collections_abc import Mapping from _collections_abc import Mapping
import torch import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
from .data_format import LinkPredictionEdgeFormat from .data_format import LinkPredictionEdgeFormat
@functional_datapipe("sample_negative")
class NegativeSampler(Mapper): class NegativeSampler(Mapper):
""" """
A negative sampler used to generate negative samples and return A negative sampler used to generate negative samples and return
......
...@@ -3,12 +3,14 @@ ...@@ -3,12 +3,14 @@
from collections import defaultdict from collections import defaultdict
from typing import Dict from typing import Dict
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
from .base import etype_str_to_tuple from .base import etype_str_to_tuple
from .utils import unique_and_compact from .utils import unique_and_compact
@functional_datapipe("sample_subgraph")
class SubgraphSampler(Mapper): class SubgraphSampler(Mapper):
"""A subgraph sampler used to sample a subgraph from a given set of nodes """A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph.""" from a larger graph."""
......
...@@ -5,6 +5,75 @@ import torch ...@@ -5,6 +5,75 @@ import torch
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
def test_NegativeSampler_invoke():
# Instantiate graph and required datapipes.
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
negative_ratio = 2
# Invoke NegativeSampler via class constructor.
negative_sampler = gb.NegativeSampler(
item_sampler,
negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT,
)
with pytest.raises(NotImplementedError):
next(iter(negative_sampler))
# Invoke NegativeSampler via functional form.
negative_sampler = item_sampler.sample_negative(
negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT,
)
with pytest.raises(NotImplementedError):
next(iter(negative_sampler))
def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
negative_ratio = 2
# Verify iteration over UniformNegativeSampler.
def _verify(negative_sampler):
for data in negative_sampler:
src, dst = data.node_pairs
labels = data.labels
# Assertation
assert len(src) == batch_size * (negative_ratio + 1)
assert len(dst) == batch_size * (negative_ratio + 1)
assert len(labels) == batch_size * (negative_ratio + 1)
assert torch.all(torch.eq(labels[:batch_size], 1))
assert torch.all(torch.eq(labels[batch_size:], 0))
# Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph,
)
_verify(negative_sampler)
# Invoke UniformNegativeSampler via functional form.
negative_sampler = item_sampler.sample_uniform_negative(
negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph,
)
_verify(negative_sampler)
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_NegativeSampler_Independent_Format(negative_ratio): def test_NegativeSampler_Independent_Format(negative_ratio):
# Construct CSCSamplingGraph. # Construct CSCSamplingGraph.
......
...@@ -11,8 +11,14 @@ import torch ...@@ -11,8 +11,14 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test") @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo(): def test_CopyTo():
dp = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4) dp = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
# Invoke CopyTo via class constructor.
dp = gb.CopyTo(dp, "cuda") dp = gb.CopyTo(dp, "cuda")
for data in dp:
assert data.device.type == "cuda"
# Invoke CopyTo via functional form.
dp = dp.copy_to("cuda")
for data in dp: for data in dp:
assert data.device.type == "cuda" assert data.device.type == "cuda"
......
...@@ -4,6 +4,35 @@ import torch ...@@ -4,6 +4,35 @@ import torch
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
def test_FeatureFetcher_invoke():
# Prepare graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(20, 0.15)
a = torch.randint(0, 10, (graph.num_nodes,))
b = torch.randint(0, 10, (graph.num_edges,))
features = {}
keys = [("node", None, "a"), ("edge", None, "b")]
features[keys[0]] = gb.TorchBasedFeature(a)
features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
# Invoke FeatureFetcher via class constructor.
datapipe = gb.NeighborSampler(datapipe, graph, fanouts)
datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
assert len(list(datapipe)) == 5
# Invoke FeatureFetcher via functional form.
datapipe = datapipe.sample_neighbor(graph, fanouts).fetch_feature(
feature_store, ["a"], ["b"]
)
assert len(list(datapipe)) == 5
def test_FeatureFetcher_homo(): def test_FeatureFetcher_homo():
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15)
a = torch.randint(0, 10, (graph.num_nodes,)) a = torch.randint(0, 10, (graph.num_nodes,))
......
...@@ -6,6 +6,42 @@ import torchdata.datapipes as dp ...@@ -6,6 +6,42 @@ import torchdata.datapipes as dp
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
def test_SubgraphSampler_invoke():
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2)
# Invoke via class constructor.
datapipe = gb.SubgraphSampler(datapipe)
with pytest.raises(NotImplementedError):
next(iter(datapipe))
# Invokde via functional form.
datapipe = datapipe.sample_subgraph()
with pytest.raises(NotImplementedError):
next(iter(datapipe))
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
datapipe = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
# Invoke via class constructor.
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(datapipe, graph, fanouts)
assert len(list(datapipe)) == 5
# Invokde via functional form.
if labor:
datapipe = datapipe.sample_layer_neighbor(graph, fanouts)
else:
datapipe = datapipe.sample_neighbor(graph, fanouts)
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor): def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment