Commit 50eb1014 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt] Refactor NeighborSampler and expose fine-grained datapipes. (#6983)

parent e602ab1b
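A minimal sketch of how the newly exposed fine-grained datapipes compose (illustrative only, not part of this commit; it assumes `graph` is a GraphBolt sampling graph and that `datapipe` already yields minibatches whose `_seed_nodes` and `sampled_subgraphs` fields were populated by an earlier stage such as NeighborSampler._prepare):

    import torch

    # One sampling hop is one sample_per_layer/compact_per_layer pair.
    fanout = torch.LongTensor([10])
    datapipe = datapipe.sample_per_layer(
        graph.sample_neighbors, fanout, False, None
    )  # arguments: sampler, fanout, replace, prob_name
    datapipe = datapipe.compact_per_layer(True)  # deduplicate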
"""Neighbor subgraph samplers for GraphBolt."""
from functools import partial
import torch
from torch.utils.data import functional_datapipe
from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..minibatch_transformer import MiniBatchTransformer
from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import SampledSubgraphImpl
@@ -12,8 +15,66 @@ from .sampled_subgraph_impl import SampledSubgraphImpl
__all__ = ["NeighborSampler", "LayerNeighborSampler"]
@functional_datapipe("sample_per_layer")
class SamplePerLayer(MiniBatchTransformer):
"""Sample neighbor edges from a graph for a single layer."""
def __init__(self, datapipe, sampler, fanout, replace, prob_name):
super().__init__(datapipe, self._sample_per_layer)
self.sampler = sampler
self.fanout = fanout
self.replace = replace
self.prob_name = prob_name
def _sample_per_layer(self, minibatch):
subgraph = self.sampler(
minibatch._seed_nodes, self.fanout, self.replace, self.prob_name
)
minibatch.sampled_subgraphs.insert(0, subgraph)
return minibatch
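# Note on SamplePerLayer._sample_per_layer, added for clarity (not in the
# original code): each hop's subgraph is prepended with insert(0, ...), so
# after all layers run, minibatch.sampled_subgraphs is ordered from the
# outermost hop (whose rows become the input nodes) down to the hop sampled
# directly from the original seeds.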
@functional_datapipe("compact_per_layer")
class CompactPerLayer(MiniBatchTransformer):
"""Compact the sampled edges for a single layer."""
def __init__(self, datapipe, deduplicate):
super().__init__(datapipe, self._compact_per_layer)
self.deduplicate = deduplicate
def _compact_per_layer(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0]
seeds = minibatch._seed_nodes
if self.deduplicate:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
minibatch._seed_nodes = original_row_node_ids
minibatch.sampled_subgraphs[0] = subgraph
return minibatch
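# Note on CompactPerLayer._compact_per_layer, added for clarity (not in the
# original code): compaction relabels the row indices of sampled_csc to
# positions in original_row_node_ids. With deduplicate=True,
# unique_and_compact_csc_formats merges repeated neighbors within the layer;
# otherwise compact_csc_format keeps duplicates.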
@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
# pylint: disable=abstract-method
"""Sample neighbor edges from a graph and return a subgraph.
Functional name: :obj:`sample_neighbor`.
@@ -95,6 +156,7 @@ class NeighborSampler(SubgraphSampler):
)]
"""
# pylint: disable=useless-super-delegation
def __init__(
self,
datapipe,
@@ -103,26 +165,19 @@ class NeighborSampler(SubgraphSampler):
replace=False,
prob_name=None,
deduplicate=True,
sampler=None,
):
super().__init__(datapipe)
self.graph = graph
# Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
self.replace = replace
self.prob_name = prob_name
self.deduplicate = deduplicate
self.sampler = graph.sample_neighbors
if sampler is None:
sampler = graph.sample_neighbors
super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
)
def sample_subgraphs(self, seeds, seeds_timestamp):
subgraphs = []
num_layers = len(self.fanouts)
def _prepare(self, node_type_to_id, minibatch):
seeds = minibatch._seed_nodes
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.node_type_to_id.keys())
ntypes = list(node_type_to_id.keys())
# Loop over different seeds to extract the device they are on.
device = None
dtype = None
@@ -134,42 +189,37 @@ class NeighborSampler(SubgraphSampler):
seeds = {
ntype: seeds.get(ntype, default_tensor) for ntype in ntypes
}
for hop in range(num_layers):
subgraph = self.sampler(
seeds,
self.fanouts[hop],
self.replace,
self.prob_name,
)
if self.deduplicate:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
minibatch._seed_nodes = seeds
minibatch.sampled_subgraphs = []
return minibatch
@staticmethod
def _set_input_nodes(minibatch):
minibatch.input_nodes = minibatch._seed_nodes
return minibatch
# pylint: disable=arguments-differ
def sampling_stages(
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
):
datapipe = datapipe.transform(
partial(self._prepare, graph.node_type_to_id)
)
else:
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
for fanout in reversed(fanouts):
# Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
datapipe = datapipe.sample_per_layer(
sampler, fanout, replace, prob_name
)
subgraphs.insert(0, subgraph)
seeds = original_row_node_ids
return seeds, subgraphs
datapipe = datapipe.compact_per_layer(deduplicate)
return datapipe.transform(self._set_input_nodes)
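        # Note added for clarity (not in the original code): for
        # fanouts=[10, 10] the chain built above is
        #     transform(_prepare)
        #     -> sample_per_layer -> compact_per_layer
        #     -> sample_per_layer -> compact_per_layer
        #     -> transform(_set_input_nodes)
        # reversed(fanouts) preserves the old behavior of building
        # self.fanouts with insert(0, ...): the last fanout is applied to the
        # original seeds first, and each compact_per_layer writes its
        # original_row_node_ids back to minibatch._seed_nodes to act as the
        # next hop's seeds.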
@functional_datapipe("sample_layer_neighbor")
class LayerNeighborSampler(NeighborSampler):
# pylint: disable=abstract-method
"""Sample layer neighbor edges from a graph and return a subgraph.
Functional name: :obj:`sample_layer_neighbor`.
@@ -280,5 +330,5 @@ class LayerNeighborSampler(NeighborSampler):
replace,
prob_name,
deduplicate,
graph.sample_layer_neighbors,
)
self.sampler = graph.sample_layer_neighbors
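# Note added for clarity (not in the original code): after this refactor,
# LayerNeighborSampler differs from NeighborSampler only in the sampler
# callable it forwards to the base class (graph.sample_layer_neighbors rather
# than graph.sample_neighbors); the per-layer sample/compact pipeline is
# shared through NeighborSampler.sampling_stages.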
@@ -22,21 +22,44 @@ class SubgraphSampler(MiniBatchTransformer):
Functional name: :obj:`sample_subgraph`.
This class is the base class of all subgraph samplers. Any subclass of
SubgraphSampler should implement the :meth:`sample_subgraphs` method.
SubgraphSampler should implement either the :meth:`sample_subgraphs` method
or the :meth:`sampling_stages` method to define the fine-grained sampling
stages to take advantage of optimizations provided by the GraphBolt
DataLoader.
Parameters
----------
datapipe : DataPipe
The datapipe.
args : Non-Keyword Arguments
Arguments to be passed into sampling_stages.
kwargs : Keyword Arguments
Arguments to be passed into sampling_stages.
"""
def __init__(
self,
datapipe,
*args,
**kwargs,
):
super().__init__(datapipe, self._sample)
datapipe = datapipe.transform(self._preprocess)
datapipe = self.sampling_stages(datapipe, *args, **kwargs)
datapipe = datapipe.transform(self._postprocess)
super().__init__(datapipe, self._identity)
def _sample(self, minibatch):
@staticmethod
def _identity(minibatch):
return minibatch
@staticmethod
def _postprocess(minibatch):
delattr(minibatch, "_seed_nodes")
delattr(minibatch, "_seeds_timestamp")
return minibatch
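    # Note added for clarity (not in the original code): _preprocess stashes
    # the seeds on the minibatch as the private attributes _seed_nodes and
    # _seeds_timestamp, the sampling stages consume and update them, and
    # _postprocess deletes them so they never leak out of the sampler.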
@staticmethod
def _preprocess(minibatch):
if minibatch.node_pairs is not None:
(
seeds,
@@ -44,7 +67,7 @@ class SubgraphSampler(MiniBatchTransformer):
minibatch.compacted_node_pairs,
minibatch.compacted_negative_srcs,
minibatch.compacted_negative_dsts,
) = self._node_pairs_preprocess(minibatch)
) = SubgraphSampler._node_pairs_preprocess(minibatch)
elif minibatch.seed_nodes is not None:
seeds = minibatch.seed_nodes
seeds_timestamp = (
@@ -55,13 +78,12 @@ class SubgraphSampler(MiniBatchTransformer):
f"Invalid minibatch {minibatch}: Either `node_pairs` or "
"`seed_nodes` should have a value."
)
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self.sample_subgraphs(seeds, seeds_timestamp)
minibatch._seed_nodes = seeds
minibatch._seeds_timestamp = seeds_timestamp
return minibatch
def _node_pairs_preprocess(self, minibatch):
@staticmethod
def _node_pairs_preprocess(minibatch):
use_timestamp = hasattr(minibatch, "timestamp")
node_pairs = minibatch.node_pairs
neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
@@ -191,6 +213,23 @@ class SubgraphSampler(MiniBatchTransformer):
compacted_negative_dsts if has_neg_dst else None,
)
def _sample(self, minibatch):
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self.sample_subgraphs(
minibatch._seed_nodes, minibatch._seeds_timestamp
)
return minibatch
def sampling_stages(self, datapipe):
"""The sampling stages are defined here by chaining to the datapipe. The
default implementation expects :meth:`sample_subgraphs` to be
implemented. To define fine-grained stages, this method should be
overridden.
"""
return datapipe.transform(self._sample)
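    # Illustrative sketch of overriding sampling_stages, added for clarity
    # (an assumption, not part of this diff): a subclass could build the
    # fine-grained pipeline directly, e.g.
    #
    #     def sampling_stages(self, datapipe, graph, fanout):
    #         datapipe = datapipe.sample_per_layer(
    #             graph.sample_neighbors, fanout, False, None
    #         )
    #         return datapipe.compact_per_layer(True)
    #
    # where the arguments after `datapipe` are forwarded from __init__ via
    # *args/**kwargs.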
def sample_subgraphs(self, seeds, seeds_timestamp):
"""Sample subgraphs from the given seeds, possibly with temporal constraints.
......